{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e7f697fec28a"
      },
      "source": [
        "##### Copyright 2022 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "GpAXuTgZ888M"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EFwSaNB8jF7s"
      },
      "source": [
        "\u003cstyle\u003e\n",
        "td {\n",
        "  text-align: center;\n",
        "}\n",
        "\n",
        "th {\n",
        "  text-align: center;\n",
        "}\n",
        "\u003c/style\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a934948f7030"
      },
      "source": [
        "# Neural machine translation with a Transformer and Keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a241496dc3d9"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/text/tutorials/transformer\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003e\n",
        "    View on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/transformer.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003e\n",
        "    Run in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/text/blob/master/docs/tutorials/transformer.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003e\n",
        "    View source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/text/docs/tutorials/transformer.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TCg3ElKBUSBb"
      },
      "source": [
        "This tutorial demonstrates how to create and train a [sequence-to-sequence](https://developers.google.com/machine-learning/glossary#sequence-to-sequence-task) [Transformer](https://developers.google.com/machine-learning/glossary#Transformer) model to translate [Portuguese into English](https://www.tensorflow.org/datasets/catalog/ted_hrlr_translate#ted_hrlr_translatept_to_en). The Transformer was originally proposed in [\"Attention is all you need\"](https://arxiv.org/abs/1706.03762) by Vaswani et al. (2017).\n",
        "\n",
        "Transformers are deep neural networks that replace CNNs and RNNs with [self-attention](https://developers.google.com/machine-learning/glossary#self-attention). Self-attention allows Transformers to easily transmit information across the input sequences.\n",
        "\n",
        "As explained in the [Google AI Blog post](https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html):\n",
        "\n",
        "\u003e Neural networks for machine translation typically contain an encoder reading the input sentence and generating a representation of it. A decoder then generates the output sentence word by word while consulting the representation generated by the encoder. The Transformer starts by generating initial representations, or embeddings, for each word... Then, using self-attention, it aggregates information from all of the other words, generating a new representation per word informed by the entire context, represented by the filled balls. This step is then repeated multiple times in parallel for all words, successively generating new representations."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fo1P7AN4lzdi"
      },
      "source": [
        "\u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/apply_the_transformer_to_machine_translation.gif\" alt=\"Applying the Transformer to machine translation\"\u003e\n",
        "\n",
        "Figure 1: Applying the Transformer to machine translation. Source: [Google AI Blog](https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RAxfGTJJYbQi"
      },
      "source": [
        "That's a lot to digest. The goal of this tutorial is to break it down into easy-to-understand parts. In this tutorial you will:\n",
        "\n",
        "- Prepare the data.\n",
        "- Implement necessary components:\n",
        "  - Positional embeddings.\n",
        "  - Attention layers.\n",
        "  - The encoder and decoder.\n",
        "- Build \u0026 train the Transformer.\n",
        "- Generate translations.\n",
        "- Export the model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ddvOmAXhaDXt"
      },
      "source": [
        "To get the most out of this tutorial, it helps if you know about [the basics of text generation](./text_generation.ipynb) and attention mechanisms. \n",
        "\n",
        "A Transformer is a sequence-to-sequence encoder-decoder model similar to the model in the [NMT with attention tutorial](https://www.tensorflow.org/text/tutorials/nmt_with_attention).\n",
        "A single-layer Transformer takes a little more code to write, but is almost identical to that encoder-decoder RNN model. The only difference is that the RNN layers are replaced with self-attention layers.\n",
        "This tutorial builds a 4-layer Transformer, which is larger and more powerful but not fundamentally more complex."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jk40oPm8OD51"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth\u003eThe \u003ca href=https://www.tensorflow.org/text/tutorials/nmt_with_attention\u003eRNN+Attention model\u003c/a\u003e\u003c/th\u003e\n",
        "  \u003cth\u003eA 1-layer Transformer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=411 src=\"https://www.tensorflow.org/images/tutorials/transformer/RNN+attention-words.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=400 src=\"https://www.tensorflow.org/images/tutorials/transformer/Transformer-1layer-words.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "huJ97Eh-Ue4V"
      },
      "source": [
        "After training the model in this notebook, you will be able to input a Portuguese sentence and get back its English translation.\n",
        "\n",
        "\u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/attention_map_portuguese.png\" alt=\"Attention heatmap\"\u003e\n",
        "\n",
        "Figure 2: Visualized attention weights that you can generate at the end of this tutorial."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aZL6uTTE5137"
      },
      "source": [
        "## Why Transformers are significant\n",
        "\n",
        "- Transformers excel at modeling sequential data, such as natural language.\n",
        "- Unlike [recurrent neural networks (RNNs)](./text_generation.ipynb), Transformers are parallelizable. This makes them efficient on hardware like GPUs and TPUs. The main reason is that Transformers replace recurrence with attention, so computations can happen simultaneously: layer outputs can be computed in parallel instead of in series, as in an RNN.\n",
        "- Unlike [RNNs](https://www.tensorflow.org/guide/keras/rnn) (such as [seq2seq, 2014](https://arxiv.org/abs/1409.3215)) or [convolutional neural networks (CNNs)](https://www.tensorflow.org/tutorials/images/cnn) (for example, [ByteNet](https://arxiv.org/abs/1610.10099)), Transformers can capture long-range context and dependencies between distant positions in the input or output sequences, so longer connections can be learned. Attention gives each location access to the entire input at every layer, while in RNNs and CNNs the information has to pass through many processing steps to travel a long distance, which makes it harder to learn.\n",
        "- Transformers make no assumptions about the temporal/spatial relationships across the data. This is ideal for processing a set of objects (for example, [StarCraft units](https://www.deepmind.com/blog/alphastar-mastering-the-real-time-strategy-game-starcraft-ii)).\n",
        "\n",
        "\u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/encoder_self_attention_distribution.png\" width=\"800\" alt=\"Encoder self-attention distribution for the word it from the 5th to the 6th layer of a Transformer trained on English-to-French translation\"\u003e\n",
        "\n",
        "Figure 3: The encoder self-attention distribution for the word “it” from the 5th to the 6th layer of a Transformer trained on English-to-French translation (one of eight attention heads). Source: [Google AI Blog](https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html)."
      ]
    },
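    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchAttnLead"
      },
      "source": [
        "The parallelism argument above can be made concrete with a small NumPy sketch. It is not how the Keras layers later in this tutorial are implemented. `toy_self_attention` is a bare, single-head, unmasked scaled dot-product self-attention, and `toy_rnn` is a minimal tanh RNN. Both are hypothetical helpers with random weights. The attention step handles all positions with a few matrix products, and every row of its weight matrix covers the entire sequence. The RNN has to loop over time steps because each hidden state depends on the previous one."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchAttnCode"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "toy_rng = np.random.default_rng(0)\n",
        "toy_len, toy_d = 6, 8\n",
        "toy_x = toy_rng.normal(size=(toy_len, toy_d))  # one sequence of 6 token vectors\n",
        "\n",
        "def toy_self_attention(x, wq, wk, wv):\n",
        "  # Every position is processed at once with (seq, d) matrix products.\n",
        "  q, k, v = x @ wq, x @ wk, x @ wv\n",
        "  scores = q @ k.T / np.sqrt(x.shape[-1])         # (seq, seq)\n",
        "  scores -= scores.max(axis=-1, keepdims=True)    # numerical stability\n",
        "  weights = np.exp(scores)\n",
        "  weights /= weights.sum(axis=-1, keepdims=True)  # softmax over all positions\n",
        "  return weights @ v, weights                     # (seq, d), (seq, seq)\n",
        "\n",
        "def toy_rnn(x, wx, wh):\n",
        "  # Each hidden state depends on the previous one, so this loop is sequential.\n",
        "  h = np.zeros(wh.shape[0])\n",
        "  outputs = []\n",
        "  for x_t in x:\n",
        "    h = np.tanh(x_t @ wx + h @ wh)\n",
        "    outputs.append(h)\n",
        "  return np.stack(outputs)\n",
        "\n",
        "toy_wq, toy_wk, toy_wv, toy_wx, toy_wh = (\n",
        "    toy_rng.normal(size=(toy_d, toy_d)) for _ in range(5))\n",
        "\n",
        "toy_attn_out, toy_attn_weights = toy_self_attention(toy_x, toy_wq, toy_wk, toy_wv)\n",
        "toy_rnn_out = toy_rnn(toy_x, toy_wx, toy_wh)\n",
        "\n",
        "print(toy_attn_out.shape, toy_attn_weights.shape, toy_rnn_out.shape)  # (6, 8) (6, 6) (6, 8)\n",
        "# Row 0: how much position 0 attends to each of the 6 positions (sums to 1).\n",
        "print(toy_attn_weights[0].round(2))"
      ]
    },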
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "swymtxpl7W7w"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OfV1batAwq9j"
      },
      "source": [
        "Begin by installing [TensorFlow Datasets](https://tensorflow.org/datasets) for loading the dataset and [TensorFlow Text](https://www.tensorflow.org/text) for text preprocessing:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XFG0NDRu5mYQ"
      },
      "outputs": [],
      "source": [
        "# Install the most recent version of TensorFlow to use the improved\n",
        "# masking support for `tf.keras.layers.MultiHeadAttention`.\n",
        "!apt install --allow-change-held-packages libcudnn8=8.1.0.77-1+cuda11.2\n",
        "!pip uninstall -y -q tensorflow keras tensorflow-estimator tensorflow-text\n",
        "!pip install protobuf~=3.20.3\n",
        "!pip install -q tensorflow_datasets\n",
        "!pip install -q -U tensorflow-text tensorflow"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0GYpLBSjxJmG"
      },
      "source": [
        "Import the necessary modules:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JjJJyJTZYebt"
      },
      "outputs": [],
      "source": [
        "import logging\n",
        "import time\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow as tf\n",
        "\n",
        "import tensorflow_text"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xf_WUi2HLhzf"
      },
      "source": [
        "## Data handling\n",
        "\n",
        "This section downloads the dataset and the subword tokenizer from [this tutorial](https://www.tensorflow.org/text/guide/subwords_tokenizer), then wraps it all up in a `tf.data.Dataset` for training.\n",
        "\n",
        " \u003csection class=\"expandable tfo-display-only-on-site\"\u003e\n",
        " \u003cbutton type=\"button\" class=\"button-red button expand-control\"\u003eToggle section\u003c/button\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-cCvXbPkccV1"
      },
      "source": [
        "### Download the dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LTEVgBxklzdq"
      },
      "source": [
        "Use TensorFlow Datasets to load the [Portuguese-English translation dataset](https://www.tensorflow.org/datasets/catalog/ted_hrlr_translate#ted_hrlr_translatept_to_en) from the TED Talks Open Translation Project. This dataset contains approximately 52,000 training, 1,200 validation and 1,800 test examples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8q9t4FmN96eN"
      },
      "outputs": [],
      "source": [
        "examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en',\n",
        "                               with_info=True,\n",
        "                               as_supervised=True)\n",
        "\n",
        "train_examples, val_examples = examples['train'], examples['validation']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZA4cw7F_DmSt"
      },
      "source": [
        "The `tf.data.Dataset` object returned by TensorFlow Datasets yields pairs of text examples:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CZFAMZJyDrFn"
      },
      "outputs": [],
      "source": [
        "for pt_examples, en_examples in train_examples.batch(3).take(1):\n",
        "  print('\u003e Examples in Portuguese:')\n",
        "  for pt in pt_examples.numpy():\n",
        "    print(pt.decode('utf-8'))\n",
        "  print()\n",
        "\n",
        "  print('\u003e Examples in English:')\n",
        "  for en in en_examples.numpy():\n",
        "    print(en.decode('utf-8'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eJxTd6aVnZyh"
      },
      "source": [
        "### Set up the tokenizer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mopr6oKUlzds"
      },
      "source": [
        "Now that you have loaded the dataset, you need to tokenize the text, so that each element is represented as a [token](https://developers.google.com/machine-learning/glossary#token) or token ID (a numeric representation).\n",
        "\n",
        "Tokenization is the process of breaking up text into \"tokens\". Depending on the tokenizer, these tokens can represent sentence-pieces, words, subwords, or characters. To learn more about tokenization, visit [this guide](https://www.tensorflow.org/text/guide/tokenizers)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GJr_8Jz9FKgu"
      },
      "source": [
        "This tutorial uses the tokenizers built in the [subword tokenizer](https://www.tensorflow.org/text/guide/subwords_tokenizer) tutorial. That tutorial optimizes two `text.BertTokenizer` objects (one for English, one for Portuguese) for **this dataset** and exports them in a TensorFlow `saved_model` format.\n",
        "\n",
        "\u003e Note: This is different from the [original paper](https://arxiv.org/pdf/1706.03762.pdf), section 5.1, where they used a single byte-pair tokenizer for both the source and target, with a vocabulary size of 37000.\n",
        "\n",
        "Download, extract, and import the `saved_model`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QToMl0NanZPr"
      },
      "outputs": [],
      "source": [
        "model_name = 'ted_hrlr_translate_pt_en_converter'\n",
        "tf.keras.utils.get_file(\n",
        "    f'{model_name}.zip',\n",
        "    f'https://storage.googleapis.com/download.tensorflow.org/models/{model_name}.zip',\n",
        "    cache_dir='.', cache_subdir='', extract=True\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h5dbGnPXnuI1"
      },
      "outputs": [],
      "source": [
        "tokenizers = tf.saved_model.load(model_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CexgkIS1lzdt"
      },
      "source": [
        "The `tf.saved_model` contains two text tokenizers, one for English and one for Portuguese. Both have the same methods:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s-PCJijfcZ9_"
      },
      "outputs": [],
      "source": [
        "[item for item in dir(tokenizers.en) if not item.startswith('_')]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fUBljDDEFWUC"
      },
      "source": [
        "The `tokenize` method converts a batch of strings to a ragged batch of token IDs (a `tf.RaggedTensor`). This method splits punctuation, lowercases, and Unicode-normalizes the input before tokenizing. That standardization is not visible here because the input data is already standardized."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z_gPC5iwFXfU"
      },
      "outputs": [],
      "source": [
        "print('\u003e This is a batch of strings:')\n",
        "for en in en_examples.numpy():\n",
        "  print(en.decode('utf-8'))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uSkM7z8JFaVO"
      },
      "outputs": [],
      "source": [
        "encoded = tokenizers.en.tokenize(en_examples)\n",
        "\n",
        "print('\u003e This is a ragged batch of token IDs:')\n",
        "for row in encoded.to_list():\n",
        "  print(row)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nBkv7XeBFa8_"
      },
      "source": [
        "The `detokenize` method attempts to convert these token IDs back to human-readable text: "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-CFS5aAxFdpP"
      },
      "outputs": [],
      "source": [
        "round_trip = tokenizers.en.detokenize(encoded)\n",
        "\n",
        "print('\u003e This is human-readable text:')\n",
        "for line in round_trip.numpy():\n",
        "  print(line.decode('utf-8'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G-2gMSBBU-AE"
      },
      "source": [
        "The lower-level `lookup` method converts token IDs to token text:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XaCeOnswVAhI"
      },
      "outputs": [],
      "source": [
        "print('\u003e This is the text split into tokens:')\n",
        "tokens = tokenizers.en.lookup(encoded)\n",
        "tokens"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pR3vZJf1Yhg_"
      },
      "source": [
        "The output demonstrates the \"subword\" aspect of the subword tokenization.\n",
        "\n",
        "For example, the word `'searchability'` is decomposed into `'search'` and `'##ability'`, and the word `'serendipity'` into `'s'`, `'##ere'`, `'##nd'`, `'##ip'` and `'##ity'`.\n",
        "\n",
        "Note that the tokenized text includes `'[START]'` and `'[END]'` tokens."
      ]
    },
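    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchWordpieceLead"
      },
      "source": [
        "WordPiece tokenizers such as `text.BertTokenizer` split each word with greedy longest-match-first matching. At each position they take the longest vocabulary entry that fits and mark pieces that continue a word with `##`. The plain-Python sketch below reimplements that idea over a tiny, made-up vocabulary chosen to reproduce the two examples above. `toy_wordpiece` and `toy_vocab` are hypothetical. The real tokenizer uses the vocabulary learned from this dataset and handles details this sketch skips, such as a maximum word length, so treat it as an illustration rather than a reference implementation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchWordpieceCode"
      },
      "outputs": [],
      "source": [
        "def toy_wordpiece(word, vocab, unknown='[UNK]'):\n",
        "  # Greedy longest-match-first: at each position, take the longest\n",
        "  # vocabulary entry that matches, then continue after it.\n",
        "  pieces = []\n",
        "  start = 0\n",
        "  while start < len(word):\n",
        "    end = len(word)\n",
        "    match = None\n",
        "    while end > start:\n",
        "      piece = word[start:end]\n",
        "      if start > 0:\n",
        "        piece = '##' + piece  # pieces that continue a word carry the ## prefix\n",
        "      if piece in vocab:\n",
        "        match = piece\n",
        "        break\n",
        "      end -= 1\n",
        "    if match is None:\n",
        "      return [unknown]  # no piece fits, so the whole word is unknown\n",
        "    pieces.append(match)\n",
        "    start = end\n",
        "  return pieces\n",
        "\n",
        "# A tiny, made-up vocabulary (the real one was learned from this dataset).\n",
        "toy_vocab = {'search', '##ability', 's', '##ere', '##nd', '##ip', '##ity', 'serve'}\n",
        "\n",
        "print(toy_wordpiece('searchability', toy_vocab))  # ['search', '##ability']\n",
        "print(toy_wordpiece('serendipity', toy_vocab))    # ['s', '##ere', '##nd', '##ip', '##ity']\n",
        "print(toy_wordpiece('xyz', toy_vocab))            # ['[UNK]']"
      ]
    },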
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g_4vdnhSaATh"
      },
      "source": [
        "The distribution of tokens per example in the dataset is as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KRbke-iaaHFI"
      },
      "outputs": [],
      "source": [
        "lengths = []\n",
        "\n",
        "for pt_examples, en_examples in train_examples.batch(1024):\n",
        "  pt_tokens = tokenizers.pt.tokenize(pt_examples)\n",
        "  lengths.append(pt_tokens.row_lengths())\n",
        "  \n",
        "  en_tokens = tokenizers.en.tokenize(en_examples)\n",
        "  lengths.append(en_tokens.row_lengths())\n",
        "  print('.', end='', flush=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9ucA1q3GaK_n"
      },
      "outputs": [],
      "source": [
        "all_lengths = np.concatenate(lengths)\n",
        "\n",
        "plt.hist(all_lengths, np.linspace(0, 500, 101))\n",
        "plt.ylim(plt.ylim())\n",
        "max_length = max(all_lengths)\n",
        "plt.plot([max_length, max_length], plt.ylim())\n",
        "plt.title(f'Maximum tokens per example: {max_length}');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-Yb35sTJcZq9"
      },
      "source": [
        "### Set up a data pipeline with `tf.data`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JZHsns5obJhN"
      },
      "source": [
        "The following function takes batches of text as input and converts them to a format suitable for training.\n",
        "\n",
        "1. It tokenizes them into ragged batches.\n",
        "2. It trims each to be no longer than `MAX_TOKENS`.\n",
        "3. It splits the target (English) tokens into inputs and labels. These are shifted by one step so that at each input location the `label` is the id of the next token.\n",
        "4. It converts the `RaggedTensor`s to padded dense `Tensor`s.\n",
        "5. It returns an `(inputs, labels)` pair.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6shgzEck3FiV"
      },
      "outputs": [],
      "source": [
        "MAX_TOKENS=128\n",
        "def prepare_batch(pt, en):\n",
        "    pt = tokenizers.pt.tokenize(pt)      # Output is ragged.\n",
        "    pt = pt[:, :MAX_TOKENS]    # Trim to MAX_TOKENS.\n",
        "    pt = pt.to_tensor()  # Convert to 0-padded dense Tensor\n",
        "\n",
        "    en = tokenizers.en.tokenize(en)\n",
        "    en = en[:, :(MAX_TOKENS+1)]\n",
        "    en_inputs = en[:, :-1].to_tensor()  # Drop the [END] tokens\n",
        "    en_labels = en[:, 1:].to_tensor()   # Drop the [START] tokens\n",
        "\n",
        "    return (pt, en_inputs), en_labels"
      ]
    },
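    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchShiftLead"
      },
      "source": [
        "Steps 2 to 4 can be reproduced on made-up token IDs without the tokenizer. The NumPy sketch below mirrors the English half of `prepare_batch`, using plain Python lists in place of a `RaggedTensor`. `toy_prepare_targets` is a hypothetical name, and treating `2` as `[START]`, `3` as `[END]`, and `0` as padding is an assumption made only for this example. The second call shows a side effect of trimming that is easy to miss: a sequence longer than the limit loses its `[END]` token, so its labels never contain one."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchShiftCode"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def toy_prepare_targets(batch, max_tokens):\n",
        "  # batch: a list of variable-length token-ID lists, like the rows of a RaggedTensor.\n",
        "  trimmed = [row[:max_tokens + 1] for row in batch]  # like en[:, :(MAX_TOKENS+1)]\n",
        "  inputs = [row[:-1] for row in trimmed]              # drop the last ID\n",
        "  labels = [row[1:] for row in trimmed]               # drop the first ID\n",
        "  width = max(len(row) for row in inputs)\n",
        "\n",
        "  def pad(rows):\n",
        "    # Right-pad with 0, like RaggedTensor.to_tensor() does by default.\n",
        "    return np.array([row + [0] * (width - len(row)) for row in rows])\n",
        "\n",
        "  return pad(inputs), pad(labels)\n",
        "\n",
        "# Made-up IDs, assuming 2 = [START], 3 = [END] and 0 = padding.\n",
        "toy_batch = [[2, 15, 27, 3], [2, 8, 3]]\n",
        "toy_inputs, toy_labels = toy_prepare_targets(toy_batch, max_tokens=128)\n",
        "print(toy_inputs.tolist())  # [[2, 15, 27], [2, 8, 0]]\n",
        "print(toy_labels.tolist())  # [[15, 27, 3], [8, 3, 0]]\n",
        "\n",
        "# With a tight limit, trimming also cuts off the [END] token.\n",
        "toy_short_inputs, toy_short_labels = toy_prepare_targets([[2, 5, 6, 7, 8, 3]], max_tokens=3)\n",
        "print(toy_short_inputs.tolist(), toy_short_labels.tolist())  # [[2, 5, 6]] [[5, 6, 7]]"
      ]
    },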
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dAroQ6xelzdx"
      },
      "source": [
        "The function below converts a dataset of text examples into batches for training:\n",
        "\n",
        "1. `shuffle` randomizes the order of the examples, using a buffer of `BUFFER_SIZE` elements.\n",
        "2. `batch` groups the examples into batches of `BATCH_SIZE`. Batching happens before tokenization because the tokenizer is much more efficient on large batches.\n",
        "3. `map` applies `prepare_batch` to each batch: it tokenizes the text, trims each sequence to `MAX_TOKENS`, splits the English tokens into inputs and labels, and pads the results into dense tensors.\n",
        "4. Finally, `prefetch` runs the dataset in parallel with the model to ensure that data is available when needed. See [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance.ipynb) for details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bcRp7VcQ5m6g"
      },
      "outputs": [],
      "source": [
        "BUFFER_SIZE = 20000\n",
        "BATCH_SIZE = 64"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BUN_jLBTwNxk"
      },
      "outputs": [],
      "source": [
        "def make_batches(ds):\n",
        "  return (\n",
        "      ds\n",
        "      .shuffle(BUFFER_SIZE)\n",
        "      .batch(BATCH_SIZE)\n",
        "      .map(prepare_batch, tf.data.AUTOTUNE)\n",
        "      .prefetch(buffer_size=tf.data.AUTOTUNE))"
      ]
    },
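    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchShuffleLead"
      },
      "source": [
        "`shuffle(BUFFER_SIZE)` does not shuffle the whole dataset at once. According to the `tf.data.Dataset.shuffle` documentation, it fills a buffer with `buffer_size` elements, randomly samples elements from that buffer, and replaces each sampled element with the next one from the input. The pure-Python sketch below imitates that documented behavior to show how the buffer size limits how far an element can move. It is not TensorFlow's implementation, and `toy_buffered_shuffle` is a hypothetical name. The documentation also notes that a perfect shuffle needs a buffer at least as large as the dataset. With roughly 52,000 training examples, a `BUFFER_SIZE` of 20,000 therefore gives a partial shuffle."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchShuffleCode"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "\n",
        "def toy_buffered_shuffle(items, buffer_size, seed=0):\n",
        "  # Imitates the documented behavior of tf.data.Dataset.shuffle: keep up to\n",
        "  # buffer_size elements, emit a random one, then refill from the input.\n",
        "  rng = random.Random(seed)\n",
        "  buffer = []\n",
        "  for item in items:\n",
        "    buffer.append(item)\n",
        "    if len(buffer) == buffer_size:\n",
        "      yield buffer.pop(rng.randrange(len(buffer)))\n",
        "  rng.shuffle(buffer)  # at the end of the input, drain the rest in random order\n",
        "  yield from buffer\n",
        "\n",
        "print(list(toy_buffered_shuffle(range(20), buffer_size=1)))   # unchanged order\n",
        "print(list(toy_buffered_shuffle(range(20), buffer_size=3)))   # only local reordering\n",
        "print(list(toy_buffered_shuffle(range(20), buffer_size=20)))  # a full shuffle"
      ]
    },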
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FX_h3tCnwgR4"
      },
      "source": [
        " \u003c/section\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "itSWqk-ivrRg"
      },
      "source": [
        "## Test the Dataset "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BSswr5TKvoNM"
      },
      "outputs": [],
      "source": [
        "# Create training and validation set batches.\n",
        "train_batches = make_batches(train_examples)\n",
        "val_batches = make_batches(val_examples)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PSufllC7wooA"
      },
      "source": [
        "The resulting `tf.data.Dataset` objects are set up for training with Keras.\n",
        "Keras `Model.fit` training expects `(inputs, labels)` pairs.\n",
        "The `inputs` are pairs of tokenized Portuguese and English sequences, `(pt, en)`.\n",
        "The `labels` are the same English sequences shifted by 1.\n",
        "This shift ensures that at each location in the input `en` sequence, the `label` is the next token.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JJdJttsF751"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth\u003eInputs at the bottom, labels at the top.\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=400 src=\"https://www.tensorflow.org/images/tutorials/transformer/Transformer-1layer-words.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tsF751JJdJt"
      },
      "source": [
        "This is the same as the [text generation tutorial](text_generation.ipynb),\n",
        "except here you have additional input \"context\" (the Portuguese sequence) that the model is \"conditioned\" on.\n",
        "\n",
        "This setup is called \"teacher forcing\" because regardless of the model's output at each timestep, it gets the true value as input for the next timestep.\n",
        "This is a simple and efficient way to train a text generation model.\n",
        "It's efficient because you don't need to run the model sequentially; the outputs at the different sequence locations can be computed in parallel.\n",
        "\n",
        "You might have expected the `input, output` pairs to simply be the `Portuguese, English` sequences.\n",
        "Given the Portuguese sequence, the model would try to generate the English sequence.\n",
        "\n",
        "It's possible to train a model that way. You'd need to write out the inference loop and pass the model's output back to the input.\n",
        "It's slower (time steps can't run in parallel), and a harder task to learn (the model can't get the end of a sentence right until it gets the beginning right),\n",
        "but it can give a more stable model because the model has to learn to correct its own errors during training."
      ]
    },
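    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchTeacherLead"
      },
      "source": [
        "The difference between the two setups can be sketched with a trivial stand-in for the decoder. `toy_next_token` is hypothetical: it maps each token ID to a guess for the next one. This keeps it causal like the real decoder, but it ignores the Portuguese context. With teacher forcing, every position is predicted from the true previous token in a single vectorized call, so a wrong guess at one position does not affect the others. In the autoregressive loop, each step waits for the previous prediction, so the first mistake (`3` instead of `5`) is carried through the rest of the sequence."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchTeacherCode"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def toy_next_token(tokens):\n",
        "  # Hypothetical stand-in for a decoder: at each position it predicts\n",
        "  # (token + 1) % 10 from the current token only. It is causal, like a\n",
        "  # Transformer decoder, but ignores the Portuguese context entirely.\n",
        "  return (np.asarray(tokens) + 1) % 10\n",
        "\n",
        "toy_target = np.array([2, 5, 6, 4, 3])  # made-up IDs: [START] ... [END]\n",
        "toy_dec_inputs = toy_target[:-1]        # [2, 5, 6, 4]\n",
        "toy_dec_labels = toy_target[1:]         # [5, 6, 4, 3]\n",
        "\n",
        "# Teacher forcing: the true tokens go in, and all positions are predicted in one call.\n",
        "forced_preds = toy_next_token(toy_dec_inputs)\n",
        "forced_accuracy = float((forced_preds == toy_dec_labels).mean())\n",
        "print(forced_preds.tolist(), forced_accuracy)  # [3, 6, 7, 5] 0.25\n",
        "\n",
        "# Autoregressive: each prediction is fed back in, so the steps run one after another.\n",
        "toy_generated = [int(toy_dec_inputs[0])]\n",
        "for _ in range(len(toy_dec_labels)):\n",
        "  toy_generated.append(int(toy_next_token(toy_generated)[-1]))\n",
        "print(toy_generated[1:])  # [3, 4, 5, 6]"
      ]
    },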
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CAw2XjRwLFWr"
      },
      "outputs": [],
      "source": [
        "for (pt, en), en_labels in train_batches.take(1):\n",
        "  break\n",
        "\n",
        "print(pt.shape)\n",
        "print(en.shape)\n",
        "print(en_labels.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tzo3JKaqx46g"
      },
      "source": [
        "The `en` and `en_labels` are the same, just shifted by 1:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "apFeC-WWxzR4"
      },
      "outputs": [],
      "source": [
        "print(en[0][:10])\n",
        "print(en_labels[0][:10])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7e7hKcxn6-zd"
      },
      "source": [
        "## Define the components"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HVE5j6JlcAps"
      },
      "source": [
        "There's a lot going on inside a Transformer. The important things to remember are:\n",
        "\n",
        "1. It follows the same general pattern as a standard sequence-to-sequence model with an encoder and a decoder.\n",
        "2. If you work through it step by step it will all make sense."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O0R4bYJ0DiFR"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe original Transformer diagram\u003c/th\u003e\n",
        "  \u003cth colspan=1\u003eA representation of a 4-layer Transformer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=400 src=\"https://www.tensorflow.org/images/tutorials/transformer/transformer.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=307 src=\"https://www.tensorflow.org/images/tutorials/transformer/Transformer-4layer-compact.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e\n",
        "\n",
        "Each of the components in these two diagrams will be explained as you progress through the tutorial."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YS75Y-9-lkzn"
      },
      "source": [
        "### The embedding and positional encoding layer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "26l90xiq3Nis"
      },
      "source": [
        "The inputs to both the encoder and decoder use the same embedding and positional encoding logic. \n",
        "\n",
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe embedding and positional encoding layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/PositionalEmbedding.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "279u2DiDlmdS"
      },
      "source": [
        "Given a sequence of tokens, both the input tokens (Portuguese) and target tokens (English) have to be converted to vectors using a `tf.keras.layers.Embedding` layer.\n",
        "\n",
        "The attention layers used throughout the model see their input as a set of vectors, with no order. Since the model doesn't contain any recurrent or convolutional layers, it needs some way to identify word order. Otherwise it would see the input sequence as a [bag of words](https://developers.google.com/machine-learning/glossary#bag-of-words) instance: `how are you`, `how you are`, `you how are`, and so on would be indistinguishable.\n",
        "\n",
        "A Transformer adds a \"Positional Encoding\" to the embedding vectors. It uses a set of sines and cosines at different frequencies (across the sequence). By definition nearby elements will have similar position encodings."
      ]
    },
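    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchBagOfWords"
      },
      "source": [
        "To see why position information is needed, here is a minimal NumPy sketch of plain dot-product self-attention with no learned projections and no positional encoding (a simplification for illustration, not the layer used in this tutorial). The three random vectors are hypothetical stand-ins for the embeddings of `how`, `are` and `you`. Reordering the inputs only reorders the outputs in the same way, so nothing in the computation depends on word order:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def self_attention(x):\n",
        "  # Dot-product self-attention with no learned projections and no positions.\n",
        "  scores = x @ x.T / np.sqrt(x.shape[-1])\n",
        "  weights = np.exp(scores - scores.max(axis=-1, keepdims=True))\n",
        "  weights /= weights.sum(axis=-1, keepdims=True)  # softmax over each row\n",
        "  return weights @ x\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "tokens = rng.normal(size=(3, 4))  # hypothetical vectors for 'how', 'are', 'you'\n",
        "perm = [2, 0, 1]                  # reorder them as 'you how are'\n",
        "\n",
        "out = self_attention(tokens)\n",
        "out_perm = self_attention(tokens[perm])\n",
        "\n",
        "# Shuffling the inputs only shuffles the outputs in the same way.\n",
        "print(np.allclose(out[perm], out_perm))  # True\n",
        "```"
      ]
    },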
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4gcCNZP7lzdy"
      },
      "source": [
        "The original paper uses the following formula for calculating the positional encoding:\n",
        "\n",
        "$$\\Large{PE_{(pos, 2i)} = \\sin(pos / 10000^{2i / d_{model}})} $$\n",
        "$$\\Large{PE_{(pos, 2i+1)} = \\cos(pos / 10000^{2i / d_{model}})} $$\n",
        "\n",
        "Note: The code below implements it, but instead of interleaving the sines and cosines, the vectors of sines and cosines are simply concatenated. Permuting the channels like this is functionally equivalent, and just a little easier to implement and show in the plots below."
      ]
    },
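    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchInterleaveVsConcat"
      },
      "source": [
        "A quick NumPy check of the claim in the note: building the encoding with the paper's interleaved layout and with the concatenated layout used in this tutorial gives the same values, up to a fixed reordering of the channels. The helper names `pe_interleaved` and `pe_concatenated` are hypothetical and exist only for this comparison:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def pe_interleaved(length, depth):\n",
        "  # The paper's layout: channel 2i is sin, channel 2i+1 is cos,\n",
        "  # and both use the frequency 1 / 10000**(2i / depth).\n",
        "  pos = np.arange(length)[:, np.newaxis]\n",
        "  i = np.arange(depth // 2)[np.newaxis, :]\n",
        "  angles = pos / 10000 ** (2 * i / depth)\n",
        "  pe = np.zeros((length, depth))\n",
        "  pe[:, 0::2] = np.sin(angles)\n",
        "  pe[:, 1::2] = np.cos(angles)\n",
        "  return pe\n",
        "\n",
        "def pe_concatenated(length, depth):\n",
        "  # This tutorial's layout: all the sines first, then all the cosines.\n",
        "  pos = np.arange(length)[:, np.newaxis]\n",
        "  i = np.arange(depth // 2)[np.newaxis, :]\n",
        "  angles = pos / 10000 ** (2 * i / depth)\n",
        "  return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)\n",
        "\n",
        "a = pe_interleaved(50, 16)\n",
        "b = pe_concatenated(50, 16)\n",
        "\n",
        "# Reorder the concatenated channels as sin_0, cos_0, sin_1, cos_1, ...\n",
        "perm = np.stack([np.arange(8), np.arange(8) + 8], axis=-1).reshape(-1)\n",
        "print(np.allclose(a, b[:, perm]))  # True\n",
        "```"
      ]
    },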
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1Rz82wEs5biZ"
      },
      "outputs": [],
      "source": [
        "def positional_encoding(length, depth):\n",
        "  depth = depth/2\n",
        "\n",
        "  positions = np.arange(length)[:, np.newaxis]     # (seq, 1)\n",
        "  depths = np.arange(depth)[np.newaxis, :]/depth   # (1, depth)\n",
        "  \n",
        "  angle_rates = 1 / (10000**depths)         # (1, depth)\n",
        "  angle_rads = positions * angle_rates      # (pos, depth)\n",
        "\n",
        "  pos_encoding = np.concatenate(\n",
        "      [np.sin(angle_rads), np.cos(angle_rads)],\n",
        "      axis=-1) \n",
        "\n",
        "  return tf.cast(pos_encoding, dtype=tf.float32)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ra1IcbzFhnmF"
      },
      "source": [
        "The position encoding function is a stack of sines and cosines that oscillate at different frequencies depending on their location along the depth of the embedding vector. Each one oscillates along the position axis."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AKf4Ky2dhg0L"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "pos_encoding = positional_encoding(length=2048, depth=512)\n",
        "\n",
        "# Check the shape.\n",
        "print(pos_encoding.shape)\n",
        "\n",
        "# Plot the dimensions.\n",
        "plt.pcolormesh(pos_encoding.numpy().T, cmap='RdBu')\n",
        "plt.ylabel('Depth')\n",
        "plt.xlabel('Position')\n",
        "plt.colorbar()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eKqVkl9Jlzg6"
      },
      "source": [
        "By construction, these vectors align well with nearby vectors along the position axis. Below, the position encoding vectors are normalized, and the vector from position `1000` is compared, by dot product, to all the others:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CXY-8_uEhcRD"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "pos_encoding/=tf.norm(pos_encoding, axis=1, keepdims=True)\n",
        "p = pos_encoding[1000]\n",
        "dots = tf.einsum('pd,d -\u003e p', pos_encoding, p)\n",
        "plt.subplot(2,1,1)\n",
        "plt.plot(dots)\n",
        "plt.ylim([0,1])\n",
        "plt.plot([950, 950, float('nan'), 1050, 1050],\n",
        "         [0,1,float('nan'),0,1], color='k', label='Zoom')\n",
        "plt.legend()\n",
        "plt.subplot(2,1,2)\n",
        "plt.plot(dots)\n",
        "plt.xlim([950, 1050])\n",
        "plt.ylim([0,1])\n"
      ]
    },
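    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchOffsetDotProduct"
      },
      "source": [
        "A short sketch of why the plot looks the same around any position: because sin(a)sin(b) + cos(a)cos(b) = cos(a - b), the (unnormalized) dot product between two position vectors depends only on the offset between the positions, not on where they sit in the sequence. The sketch re-creates the encoding in NumPy so that it runs on its own:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def positional_encoding(length, depth):\n",
        "  # Same construction as this tutorial's function, in NumPy only.\n",
        "  depth = depth / 2\n",
        "  positions = np.arange(length)[:, np.newaxis]\n",
        "  depths = np.arange(depth)[np.newaxis, :] / depth\n",
        "  angle_rads = positions / (10000 ** depths)\n",
        "  return np.concatenate([np.sin(angle_rads), np.cos(angle_rads)], axis=-1)\n",
        "\n",
        "pe = positional_encoding(length=2048, depth=512)\n",
        "\n",
        "# Two pairs of positions, both 10 apart, in different parts of the sequence.\n",
        "near_1000 = pe[1000] @ pe[1010]\n",
        "near_200 = pe[200] @ pe[210]\n",
        "print(np.isclose(near_1000, near_200))  # True: only the offset matters\n",
        "\n",
        "# Every vector has the same squared norm: sin**2 + cos**2 summed over depth/2 pairs.\n",
        "print(np.isclose(pe[5] @ pe[5], 256.0))  # True\n",
        "```"
      ]
    },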
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LUknPLlVm99o"
      },
      "source": [
        "Use this to create a `PositionalEmbedding` layer that looks up a token's embedding vector and adds the position vector:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "838tmM1cm9cB"
      },
      "outputs": [],
      "source": [
        "class PositionalEmbedding(tf.keras.layers.Layer):\n",
        "  def __init__(self, vocab_size, d_model):\n",
        "    super().__init__()\n",
        "    self.d_model = d_model\n",
        "    self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True) \n",
        "    self.pos_encoding = positional_encoding(length=2048, depth=d_model)\n",
        "\n",
        "  def compute_mask(self, *args, **kwargs):\n",
        "    return self.embedding.compute_mask(*args, **kwargs)\n",
        "\n",
        "  def call(self, x):\n",
        "    length = tf.shape(x)[1]\n",
        "    x = self.embedding(x)\n",
        "    # This factor sets the relative scale of the embedding and the positional encoding.\n",
        "    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))\n",
        "    x = x + self.pos_encoding[tf.newaxis, :length, :]\n",
        "    return x\n"
      ]
    },
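    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchPositionalEmbedding"
      },
      "source": [
        "For readers who want to see the arithmetic without Keras, here is a NumPy sketch of what `PositionalEmbedding.call` computes. The embedding table is random rather than learned, the sizes are toy values, and the boolean mask assumes token id `0` is padding, which is the id that `mask_zero=True` treats as masked:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def positional_encoding(length, depth):\n",
        "  # NumPy copy of this tutorial's positional_encoding function.\n",
        "  depth = depth / 2\n",
        "  positions = np.arange(length)[:, np.newaxis]\n",
        "  depths = np.arange(depth)[np.newaxis, :] / depth\n",
        "  angle_rads = positions / (10000 ** depths)\n",
        "  return np.concatenate([np.sin(angle_rads), np.cos(angle_rads)], axis=-1)\n",
        "\n",
        "vocab_size, d_model = 100, 8  # toy sizes for illustration\n",
        "rng = np.random.default_rng(0)\n",
        "embedding_table = rng.normal(size=(vocab_size, d_model))  # stand-in for learned weights\n",
        "pos_encoding = positional_encoding(length=2048, depth=d_model)\n",
        "\n",
        "def positional_embedding(token_ids):\n",
        "  # token_ids: (batch, seq) integer array, where 0 is the padding id.\n",
        "  length = token_ids.shape[1]\n",
        "  x = embedding_table[token_ids]                # look up: (batch, seq, d_model)\n",
        "  x = x * np.sqrt(d_model)                      # same scale factor as the layer\n",
        "  x = x + pos_encoding[np.newaxis, :length, :]  # broadcast over the batch\n",
        "  mask = token_ids != 0                         # what mask_zero=True reports\n",
        "  return x, mask\n",
        "\n",
        "ids = np.array([[5, 17, 42, 0, 0]])\n",
        "x, mask = positional_embedding(ids)\n",
        "print(x.shape)        # (1, 5, 8)\n",
        "print(mask.tolist())  # [[True, True, True, False, False]]\n",
        "```"
      ]
    },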
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QpWnjwygmw-x"
      },
      "source": [
        "\u003e Note: The [original paper](https://arxiv.org/pdf/1706.03762.pdf), sections 3.4 and 5.1, uses a single tokenizer and weight matrix for both the source and target languages. This tutorial uses two separate tokenizers and weight matrices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tfz-EaCEDfUJ"
      },
      "outputs": [],
      "source": [
        "embed_pt = PositionalEmbedding(vocab_size=tokenizers.pt.get_vocab_size().numpy(), d_model=512)\n",
        "embed_en = PositionalEmbedding(vocab_size=tokenizers.en.get_vocab_size().numpy(), d_model=512)\n",
        "\n",
        "pt_emb = embed_pt(pt)\n",
        "en_emb = embed_en(en)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3fJZ_ArLELhJ"
      },
      "outputs": [],
      "source": [
        "en_emb._keras_mask"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mE9cEBWCMKOP"
      },
      "source": [
        "### Add and normalize\n",
        "\n",
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=2\u003eAdd and normalize\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/Add+Norm.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lfz3WjFLTEk_"
      },
      "source": [
        "These \"Add \u0026 Norm\" blocks are scattered throughout the model. Each one joins a residual connection and runs the result through a `LayerNormalization` layer.\n",
        "\n",
        "The easiest way to organize the code is around these residual blocks. The following sections will define custom layer classes for each. \n",
        "\n",
        "The residual \"Add \u0026 Norm\" blocks are included so that training is efficient. The residual connection provides a direct path for the gradient (and ensures that vectors are **updated** by the attention layers instead of **replaced**), while the normalization maintains a reasonable scale for the outputs.\n",
        "\n",
        "Note: The implementations below use the `Add` layer to ensure that Keras masks are propagated (the `+` operator does not).\n",
        "\n"
      ]
    },
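    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchAddAndNorm"
      },
      "source": [
        "A minimal NumPy sketch of one Add & Norm step, assuming a plain layer normalization without the learned scale and offset that `tf.keras.layers.LayerNormalization` also applies. The `update` array is a hypothetical stand-in for a sublayer's output, deliberately given a large scale to show that the normalization brings it back:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def layer_norm(x, epsilon=1e-3):\n",
        "  # Normalize each vector along the last axis to zero mean and unit variance.\n",
        "  # Keras LayerNormalization also learns a scale and offset (omitted here);\n",
        "  # 1e-3 matches its default epsilon.\n",
        "  mean = x.mean(axis=-1, keepdims=True)\n",
        "  var = x.var(axis=-1, keepdims=True)\n",
        "  return (x - mean) / np.sqrt(var + epsilon)\n",
        "\n",
        "def add_and_norm(x, sublayer_output):\n",
        "  # Residual connection: the sublayer output updates x instead of replacing it.\n",
        "  return layer_norm(x + sublayer_output)\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "x = rng.normal(size=(2, 5, 8))              # (batch, seq, d_model)\n",
        "update = 10.0 * rng.normal(size=(2, 5, 8))  # hypothetical sublayer output\n",
        "y = add_and_norm(x, update)\n",
        "\n",
        "print(y.shape)                                      # (2, 5, 8)\n",
        "print(np.allclose(y.mean(axis=-1), 0.0))            # True\n",
        "print(np.allclose(y.std(axis=-1), 1.0, atol=1e-3))  # True (approximately unit scale)\n",
        "```"
      ]
    },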
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vJAJ2_VlPXrZ"
      },
      "source": [
        "### The base attention layer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1tMGOGki35KI"
      },
      "source": [
        "Attention layers are used throughout the model. These are all identical except for how the attention is configured. Each one contains a `layers.MultiHeadAttention`, a `layers.LayerNormalization` and a `layers.Add`. \n",
        "\n",
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=2\u003eThe base attention layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/BaseAttention.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z6chjIrOVSYp"
      },
      "source": [
        "To implement these attention layers, start with a simple base class that just contains the component layers. Each use case will be implemented as a subclass. It's a little more code to write this way, but it keeps the intention clear."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5VLa5QcdPpv5"
      },
      "outputs": [],
      "source": [
        "class BaseAttention(tf.keras.layers.Layer):\n",
        "  def __init__(self, **kwargs):\n",
        "    super().__init__()\n",
        "    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)\n",
        "    self.layernorm = tf.keras.layers.LayerNormalization()\n",
        "    self.add = tf.keras.layers.Add()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wBY06TCqV2lv"
      },
      "source": [
        "#### Attention refresher\n",
        "\n",
        "Before you get into the specifics of each usage, here is a quick refresher on how attention works:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5BsRsq4TV5FY"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe base attention layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=430 src=\"https://www.tensorflow.org/images/tutorials/transformer/BaseAttention-new.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AtGTy7vZ5aaT"
      },
      "source": [
        "There are two inputs:\n",
        "\n",
        "1. The query sequence; the sequence being processed; the sequence doing the attending (bottom).\n",
        "2. The context sequence; the sequence being attended to (left).\n",
        "\n",
        "The output has the same shape as the query-sequence.\n",
        "\n",
        "The common comparison is that this operation is like a dictionary lookup.\n",
        "A **fuzzy**, **differentiable**, **vectorized** dictionary lookup.\n",
        "\n",
        "Here's a regular Python dictionary, with 3 keys and 3 values, being passed a single query.\n",
        "\n",
        "```\n",
        "d = {'color': 'blue', 'age': 22, 'type': 'pickup'}\n",
        "result = d['color']\n",
        "```\n",
        "\n",
        "- The `query` is what you're trying to find.\n",
        "- The `key`s describe what sort of information the dictionary has.\n",
        "- The `value` is that information.\n",
        "\n",
        "When you look up a `query` in a regular dictionary, the dictionary finds the matching `key`, and returns its associated `value`.\n",
        "The `query` either has a matching `key` or it doesn't.\n",
        "You can imagine a **fuzzy** dictionary where the keys don't have to match perfectly.\n",
        "If you looked up `d[\"species\"]` in the dictionary above, maybe you'd want it to return `\"pickup\"` since that's the best match for the query.\n",
        "\n",
        "An attention layer does a fuzzy lookup like this, but it's not just looking for the best key.\n",
        "It combines the `values` based on how well the `query` matches each `key`.\n",
        "\n",
        "How does that work? In an attention layer the `query`, `key`, and `value` are each vectors.\n",
        "Instead of doing a hash lookup, the attention layer combines the `query` and `key` vectors to determine how well they match, the \"attention score\".\n",
        "The layer returns the average across all the `values`, weighted by the \"attention scores\".\n",
        "\n",
        "Each location in the query sequence provides a `query` vector.\n",
        "The context sequence acts as the dictionary: each location in the context sequence provides a `key` and a `value` vector.\n",
        "The input vectors are not used directly; the `layers.MultiHeadAttention` layer includes `layers.Dense` layers to project the input vectors before using them.\n"
      ]
    },
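    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchFuzzyLookup"
      },
      "source": [
        "The fuzzy lookup above can be sketched in NumPy as single-head scaled dot-product attention. This leaves out the learned `Dense` projections and the multiple heads of `layers.MultiHeadAttention`; the keys, values and query are made-up numbers chosen so that the query matches the first key best:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def softmax(x, axis=-1):\n",
        "  x = x - x.max(axis=axis, keepdims=True)  # for numerical stability\n",
        "  e = np.exp(x)\n",
        "  return e / e.sum(axis=axis, keepdims=True)\n",
        "\n",
        "def attention(query, key, value):\n",
        "  # query: (n_queries, d), key: (n_keys, d), value: (n_keys, d_value)\n",
        "  scores = query @ key.T / np.sqrt(query.shape[-1])  # how well each query matches each key\n",
        "  weights = softmax(scores, axis=-1)                  # each row sums to 1\n",
        "  return weights @ value, weights                     # weighted average of the values\n",
        "\n",
        "# A made-up dictionary with 3 keys and 3 values, all as vectors.\n",
        "keys = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])\n",
        "values = np.array([[10.0], [20.0], [30.0]])\n",
        "\n",
        "# A query that points in the same direction as the first key.\n",
        "query = np.array([[8.0, 0.0]])\n",
        "\n",
        "out, weights = attention(query, keys, values)\n",
        "print(weights.round(3))  # almost all of the weight is on the first key\n",
        "print(out)               # close to 10, but still a blend of all three values\n",
        "```"
      ]
    },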
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B7QcPJvmv6ix"
      },
      "source": [
        "### The cross attention layer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o8VJZqds37QC"
      },
      "source": [
        "At the literal center of the Transformer is the cross-attention layer. This layer connects the encoder and decoder. It is the most straightforward use of attention in the model; it performs the same task as the attention block in the [NMT with attention tutorial](https://www.tensorflow.org/text/tutorials/nmt_with_attention).\n",
        "\n",
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe cross attention layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/CrossAttention.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jhscgMUNUFWP"
      },
      "source": [
        "To implement this, pass the target sequence `x` as the `query` and the `context` sequence as the `key/value` when calling the `mha` layer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kfHVbJUWv8qp"
      },
      "outputs": [],
      "source": [
        "class CrossAttention(BaseAttention):\n",
        "  def call(self, x, context):\n",
        "    attn_output, attn_scores = self.mha(\n",
        "        query=x,\n",
        "        key=context,\n",
        "        value=context,\n",
        "        return_attention_scores=True)\n",
        "   \n",
        "    # Cache the attention scores for plotting later.\n",
        "    self.last_attn_scores = attn_scores\n",
        "\n",
        "    x = self.add([x, attn_output])\n",
        "    x = self.layernorm(x)\n",
        "\n",
        "    return x"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j3tJU6aTYY1X"
      },
      "source": [
        "The caricature below shows how information flows through this layer. The columns represent the weighted sum over the context sequence.\n",
        "\n",
        "For simplicity the residual connections are not shown."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VBE5JNB26OjJ"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth\u003eThe cross attention layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=430 src=\"https://www.tensorflow.org/images/tutorials/transformer/CrossAttention-new-full.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sU9MeSvzZSA-"
      },
      "source": [
        "The output length is the length of the `query` sequence, and not the length of the context `key/value` sequence.\n",
        "\n",
        "The diagram is further simplified, below. There's no need to draw the entire \"Attention weights\" matrix.\n",
        "The point is that each `query` location can see all the `key/value` pairs in the context, but no information is exchanged between the queries."
      ]
    },
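    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchCrossAttnShapes"
      },
      "source": [
        "Both properties can be checked with a small NumPy sketch of single-head attention. It has no learned projections, so it only illustrates the shapes and the information flow, not the trained layer: the output has one row per query, and changing one query leaves the outputs at the other query locations untouched. The sequence lengths are arbitrary:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def attention(query, key, value):\n",
        "  # Single-head scaled dot-product attention, batched over the first axis.\n",
        "  scores = query @ np.swapaxes(key, -1, -2) / np.sqrt(query.shape[-1])\n",
        "  weights = np.exp(scores - scores.max(axis=-1, keepdims=True))\n",
        "  weights /= weights.sum(axis=-1, keepdims=True)\n",
        "  return weights @ value, weights\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "x = rng.normal(size=(1, 4, 8))        # query (target) sequence: 4 locations\n",
        "context = rng.normal(size=(1, 9, 8))  # key/value (context) sequence: 9 locations\n",
        "\n",
        "out, weights = attention(x, context, context)\n",
        "print(out.shape)      # (1, 4, 8): one output per query location\n",
        "print(weights.shape)  # (1, 4, 9): each query sees all 9 context locations\n",
        "\n",
        "# Changing the first query leaves the outputs at the other query locations unchanged.\n",
        "x2 = x.copy()\n",
        "x2[:, 0] += 1.0\n",
        "out2, _ = attention(x2, context, context)\n",
        "print(np.allclose(out[:, 1:], out2[:, 1:]))  # True\n",
        "```"
      ]
    },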
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GRrB_GcyKv-4"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth\u003eEach query sees the whole context.\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=430 src=\"https://www.tensorflow.org/images/tutorials/transformer/CrossAttention-new.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BCQsj7ljKv-4"
      },
      "source": [
        "Test run it on sample inputs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qw1FJV5qRk79"
      },
      "outputs": [],
      "source": [
        "sample_ca = CrossAttention(num_heads=2, key_dim=512)\n",
        "\n",
        "print(pt_emb.shape)\n",
        "print(en_emb.shape)\n",
        "print(sample_ca(en_emb, pt_emb).shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J6qrQxSpv34R"
      },
      "source": [
        "### The global self-attention layer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z-LbLRTkaTh5"
      },
      "source": [
        "This layer is responsible for processing the context sequence, and propagating information along its length:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YlYBQX3E388Y"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe global self-attention layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/SelfAttention.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w9j9LPJFbEkF"
      },
      "source": [
        "Since the context sequence is fixed while the translation is being generated, information is allowed to flow in both directions. \n",
        "\n",
        "Before Transformers and self-attention, models commonly used RNNs or CNNs to do this task:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "87Rlu8N_avBF"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eBidirectional RNNs and CNNs\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=500 src=\"https://www.tensorflow.org/images/tutorials/transformer/RNN-bidirectional.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=500 src=\"https://www.tensorflow.org/images/tutorials/transformer/CNN.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PPyXM4vabhup"
      },
      "source": [
        "RNNs and CNNs have their limitations.\n",
        "\n",
        "- The RNN allows information to flow all the way across the sequence, but it passes through many processing steps to get there (limiting gradient flow). These RNN steps have to be run sequentially and so the RNN is less able to take advantage of modern parallel devices.\n",
        "- In the CNN each location can be processed in parallel, but it only provides a limited receptive field. The receptive field only grows linearly with the number of CNN layers, so you need to stack a number of convolution layers to transmit information across the sequence ([Wavenet](https://arxiv.org/abs/1609.03499) reduces this problem by using dilated convolutions).\n",
        "\n",
        "The global self-attention layer, on the other hand, lets every sequence element directly access every other sequence element with only a few operations, and all the outputs can be computed in parallel.\n",
        "\n",
        "To implement this layer, you just need to pass the context sequence, `x`, as the `query`, `key`, and `value` arguments to the `mha` layer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RNqoTpn1wB3i"
      },
      "outputs": [],
      "source": [
        "class GlobalSelfAttention(BaseAttention):\n",
        "  def call(self, x):\n",
        "    attn_output = self.mha(\n",
        "        query=x,\n",
        "        value=x,\n",
        "        key=x)\n",
        "    x = self.add([x, attn_output])\n",
        "    x = self.layernorm(x)\n",
        "    return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jPn2D07-Jcmj"
      },
      "outputs": [],
      "source": [
        "sample_gsa = GlobalSelfAttention(num_heads=2, key_dim=512)\n",
        "\n",
        "print(pt_emb.shape)\n",
        "print(sample_gsa(pt_emb).shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nd-ga2tQfzhE"
      },
      "source": [
        "Sticking with the same style as before, you could draw it like this:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F1bcv9Zc6--k"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe global self-attention layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=330 src=\"https://www.tensorflow.org/images/tutorials/transformer/SelfAttention-new-full.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ze7D0WHOe-d8"
      },
      "source": [
        "Again, the residual connections are omitted for clarity.\n",
        "\n",
        "It's more compact, and just as accurate, to draw it like this:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "imlyNt2K7RnA"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe global self-attention layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=500 src=\"https://www.tensorflow.org/images/tutorials/transformer/SelfAttention-new.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yq4NtLymD99-"
      },
      "source": [
        "### The causal self-attention layer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VufkgF7caLze"
      },
      "source": [
        "This layer does a similar job to the global self-attention layer, but for the output sequence:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3KMEDiP63-hQ"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe causal self-attention layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/CausalSelfAttention.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0AtF1HYFEOYf"
      },
      "source": [
        "This needs to be handled differently from the encoder's global self-attention layer.  \n",
        "\n",
        "Like the models in the [text generation tutorial](https://www.tensorflow.org/text/tutorials/text_generation) and the [NMT with attention](https://www.tensorflow.org/text/tutorials/nmt_with_attention) tutorial, Transformers are \"autoregressive\" models: they generate the text one token at a time and feed that output back to the input. To make this _efficient_, these models ensure that the output for each sequence element only depends on the previous sequence elements; the models are \"causal\"."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CDyn29oahHiL"
      },
      "source": [
        "A single-direction RNN is causal by definition. To make a causal convolution, you just need to pad the input and shift the output so that it aligns correctly (use `layers.Conv1D(padding='causal')`)."
      ]
    },
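    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchCausalConv"
      },
      "source": [
        "A NumPy sketch of the padding trick for a single-channel signal and a kernel of length `k`: left-padding with `k - 1` zeros means that output `t` only sees inputs up to `t`. This illustrates the idea behind `padding='causal'`; it is not the Keras implementation, and the signal and kernel values are arbitrary:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def causal_conv1d(x, kernel):\n",
        "  # x: (seq,) single-channel signal; kernel: (k,) filter weights.\n",
        "  # Left-pad with k - 1 zeros so output[t] only sees x[t - k + 1], ..., x[t].\n",
        "  k = len(kernel)\n",
        "  padded = np.concatenate([np.zeros(k - 1), x])\n",
        "  return np.array([padded[t:t + k] @ kernel for t in range(len(x))])\n",
        "\n",
        "x = np.arange(1.0, 7.0)        # [1, 2, 3, 4, 5, 6]\n",
        "kernel = np.array([0.5, 0.5])  # average of the previous and current element\n",
        "y = causal_conv1d(x, kernel)\n",
        "print(y)  # [0.5 1.5 2.5 3.5 4.5 5.5]\n",
        "\n",
        "# Changing a later input never changes an earlier output.\n",
        "x2 = x.copy()\n",
        "x2[4] = 100.0\n",
        "print(np.allclose(causal_conv1d(x2, kernel)[:4], y[:4]))  # True\n",
        "```"
      ]
    },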
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9_1yd-LjhM3b"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eCausal RNNs and CNNs\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=500 src=\"https://www.tensorflow.org/images/tutorials/transformer/RNN.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=500 src=\"https://www.tensorflow.org/images/tutorials/transformer/CNN-causal.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4St1AWq9bIZg"
      },
      "source": [
        "A causal model is efficient in two ways: \n",
        "\n",
        "1. In training, it lets you compute loss for every location in the output sequence while executing the model just once.\n",
        "2. During inference, for each new token generated you only need to calculate its outputs; the outputs for the previous sequence elements can be reused.\n",
        "  - For an RNN you just need the RNN-state to account for previous computations (pass `return_state=True` to the RNN layer's constructor).\n",
        "  - For a CNN you would need to follow the approach of [Fast Wavenet](https://arxiv.org/abs/1611.09482)."
      ]
    },
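    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchKVCache"
      },
      "source": [
        "Point 2 also applies to causal self-attention. Here is a simplified NumPy sketch, assuming a single head with no learned projections (so the keys and values are just the input vectors): caching the vectors seen so far and computing only the newest position's output reproduces the full causal computation. The `step` helper and its list-based cache are hypothetical; this tutorial's own code does not implement caching:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def softmax(x):\n",
        "  e = np.exp(x - x.max(axis=-1, keepdims=True))\n",
        "  return e / e.sum(axis=-1, keepdims=True)\n",
        "\n",
        "def causal_attention(x):\n",
        "  # Full pass: position t attends to positions 0..t. No learned projections,\n",
        "  # so the keys and values are just the input vectors.\n",
        "  n, d = x.shape\n",
        "  scores = x @ x.T / np.sqrt(d)\n",
        "  future = np.triu(np.ones((n, n), dtype=bool), k=1)\n",
        "  return softmax(np.where(future, -np.inf, scores)) @ x\n",
        "\n",
        "def step(new_x, cache):\n",
        "  # Incremental pass (hypothetical helper): add the newest vector to the\n",
        "  # cache of earlier ones and compute the output for that position only.\n",
        "  cache.append(new_x)\n",
        "  kv = np.stack(cache)  # (t + 1, d)\n",
        "  weights = softmax(new_x @ kv.T / np.sqrt(len(new_x)))\n",
        "  return weights @ kv\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "x = rng.normal(size=(5, 4))\n",
        "\n",
        "cache = []\n",
        "incremental = np.stack([step(x[t], cache) for t in range(len(x))])\n",
        "print(np.allclose(incremental, causal_attention(x)))  # True\n",
        "```"
      ]
    },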
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WLYfIa8eiYgk"
      },
      "source": [
        "To build a causal self-attention layer, you need to use an appropriate mask when computing the attention scores and summing the attention `value`s.\n",
        "\n",
        "This is taken care of automatically if you pass `use_causal_mask=True` to the `MultiHeadAttention` layer when you call it:"
      ]
    },
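    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchCausalMask"
      },
      "source": [
        "Conceptually, the causal mask is a lower-triangular boolean matrix: row `i` allows positions `0` through `i`. A NumPy sketch of building it and applying it before the softmax follows; it illustrates the idea and is not how Keras implements `use_causal_mask`. The raw scores are random placeholders:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "seq_len = 4\n",
        "# Row i allows attention to positions 0..i: itself and everything before it.\n",
        "causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))\n",
        "print(causal_mask.astype(int))\n",
        "# [[1 0 0 0]\n",
        "#  [1 1 0 0]\n",
        "#  [1 1 1 0]\n",
        "#  [1 1 1 1]]\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "scores = rng.normal(size=(seq_len, seq_len))  # hypothetical raw attention scores\n",
        "# Blocked positions get -inf, so the softmax gives them exactly zero weight.\n",
        "masked = np.where(causal_mask, scores, -np.inf)\n",
        "weights = np.exp(masked - masked.max(axis=-1, keepdims=True))\n",
        "weights /= weights.sum(axis=-1, keepdims=True)\n",
        "\n",
        "print(np.allclose(np.triu(weights, k=1), 0.0))  # True: no weight on later positions\n",
        "print(np.allclose(weights.sum(axis=-1), 1.0))   # True: each row is still a distribution\n",
        "```"
      ]
    },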
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4MMQ-AfKD99_"
      },
      "outputs": [],
      "source": [
        "class CausalSelfAttention(BaseAttention):\n",
        "  def call(self, x):\n",
        "    attn_output = self.mha(\n",
        "        query=x,\n",
        "        value=x,\n",
        "        key=x,\n",
        "        use_causal_mask=True)\n",
        "    x = self.add([x, attn_output])\n",
        "    x = self.layernorm(x)\n",
        "    return x"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C5oumNdAjI-D"
      },
      "source": [
        "The causal mask ensures that each location only has access to itself and the locations that come before it:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aFJy5L1U8TBt"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe causal self-attention layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=330 src=\"https://www.tensorflow.org/images/tutorials/transformer/CausalSelfAttention-new-full.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JWX0RNnFjaCj"
      },
      "source": [
        "Again, the residual connections are omitted for simplicity.\n",
        "\n",
        "The more compact representation of this layer would be:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3C9qVfvh8-jp"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe causal self-attention layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=430 src=\"https://www.tensorflow.org/images/tutorials/transformer/CausalSelfAttention-new.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uQBhYEZ2jfrX"
      },
      "source": [
        "Test out the layer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x4dQuzvlD99_"
      },
      "outputs": [],
      "source": [
        "sample_csa = CausalSelfAttention(num_heads=2, key_dim=512)\n",
        "\n",
        "print(en_emb.shape)\n",
        "print(sample_csa(en_emb).shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n-IPCEkajleb"
      },
      "source": [
        "The output for early sequence elements doesn't depend on later elements, so it shouldn't matter if you trim elements before or after applying the layer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bwKlheQ-WVxl"
      },
      "outputs": [],
      "source": [
        "out1 = sample_csa(embed_en(en[:, :3])) \n",
        "out2 = sample_csa(embed_en(en))[:, :3]\n",
        "\n",
        "tf.reduce_max(abs(out1 - out2)).numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jOVv38ynuBW-"
      },
      "source": [
        "Note: When using Keras masks, the output values at invalid locations are not well defined, so the above may not hold for masked regions."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nLjScSWQv9M5"
      },
      "source": [
        "### The feed forward network"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pz0HBopX_VdU"
      },
      "source": [
        "The transformer also includes this point-wise feed-forward network in both the encoder and decoder:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bDHMWoZ94AUU"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe feed forward network\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/FeedForward.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Yb-IV0Nlzd0"
      },
      "source": [
        "The network consists of two linear layers (`tf.keras.layers.Dense`) with a ReLU activation in between, followed by a dropout layer. As with the attention layers, the code here also includes the residual connection and layer normalization:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rAYLeu0uwXYK"
      },
      "outputs": [],
      "source": [
        "class FeedForward(tf.keras.layers.Layer):\n",
        "  def __init__(self, d_model, dff, dropout_rate=0.1):\n",
        "    super().__init__()\n",
        "    self.seq = tf.keras.Sequential([\n",
        "      tf.keras.layers.Dense(dff, activation='relu'),\n",
        "      tf.keras.layers.Dense(d_model),\n",
        "      tf.keras.layers.Dropout(dropout_rate)\n",
        "    ])\n",
        "    self.add = tf.keras.layers.Add()\n",
        "    self.layer_norm = tf.keras.layers.LayerNormalization()\n",
        "\n",
        "  def call(self, x):\n",
        "    # Residual connection around the feed-forward sublayer, then normalize.\n",
        "    x = self.add([x, self.seq(x)])\n",
        "    x = self.layer_norm(x)\n",
        "    return x\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eQBlOVQU_hUt"
      },
      "source": [
        "Test the layer. The output has the same shape as the input:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r-Y8Yqi1_hUt"
      },
      "outputs": [],
      "source": [
        "sample_ffn = FeedForward(512, 2048)\n",
        "\n",
        "print(en_emb.shape)\n",
        "print(sample_ffn(en_emb).shape)"
      ]
    },
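    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ffnParamSketch01"
      },
      "source": [
        "As a sanity check on model size, you can count the `FeedForward` layer's trainable weights by hand: two `Dense` kernels with their biases, plus the `LayerNormalization` scale and offset, each of size `d_model`. The `ffn_param_count` function below is a hypothetical helper for this arithmetic only; after `sample_ffn` has been called once, `sample_ffn.count_params()` should report the same total:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ffnParamSketch02"
      },
      "outputs": [],
      "source": [
        "def ffn_param_count(d_model, dff):\n",
        "  # First `Dense`: a (d_model, dff) kernel plus a bias of size dff.\n",
        "  dense1 = d_model * dff + dff\n",
        "  # Second `Dense`: a (dff, d_model) kernel plus a bias of size d_model.\n",
        "  dense2 = dff * d_model + d_model\n",
        "  # `LayerNormalization`: one scale and one offset per feature.\n",
        "  layer_norm = 2 * d_model\n",
        "  # `Dropout` and `Add` have no weights.\n",
        "  return dense1 + dense2 + layer_norm\n",
        "\n",
        "print(ffn_param_count(512, 2048))  # 2100736, the sizes used in the test above\n",
        "print(ffn_param_count(128, 512))   # 131968, the sizes used for training below"
      ]
    },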
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QFv-FNYUmvpn"
      },
      "source": [
        "### The encoder layer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zk-DAL2xv4PZ"
      },
      "source": [
        "The encoder contains a stack of `N` encoder layers, where each `EncoderLayer` contains a `GlobalSelfAttention` layer and a `FeedForward` layer:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RgPaE3f44Cgh"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe encoder layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/EncoderLayer.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8kRUT__Ly9HH"
      },
      "source": [
        "Here is the definition of the `EncoderLayer`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ncyS-Ms3i2x_"
      },
      "outputs": [],
      "source": [
        "class EncoderLayer(tf.keras.layers.Layer):\n",
        "  def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1):\n",
        "    super().__init__()\n",
        "\n",
        "    self.self_attention = GlobalSelfAttention(\n",
        "        num_heads=num_heads,\n",
        "        key_dim=d_model,\n",
        "        dropout=dropout_rate)\n",
        "\n",
        "    self.ffn = FeedForward(d_model, dff, dropout_rate)\n",
        "\n",
        "  def call(self, x):\n",
        "    x = self.self_attention(x)\n",
        "    x = self.ffn(x)\n",
        "    return x"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QeXHMUlb6q6F"
      },
      "source": [
        "And a quick test. The output has the same shape as the input:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AzZRXdO0mI48"
      },
      "outputs": [],
      "source": [
        "sample_encoder_layer = EncoderLayer(d_model=512, num_heads=8, dff=2048)\n",
        "\n",
        "print(pt_emb.shape)\n",
        "print(sample_encoder_layer(pt_emb).shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SE1H51Ajm0q1"
      },
      "source": [
        "### The encoder"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fym9ah11ykMd"
      },
      "source": [
        "Next, build the encoder."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dXI2B-Ad4ETO"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe encoder\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/Encoder.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DA6sVo5rlzd3"
      },
      "source": [
        "The encoder consists of:\n",
        "\n",
        "- A `PositionalEmbedding` layer at the input.\n",
        "- A stack of `EncoderLayer` layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jpEox7gJ8FCI"
      },
      "outputs": [],
      "source": [
        "class Encoder(tf.keras.layers.Layer):\n",
        "  def __init__(self, *, num_layers, d_model, num_heads,\n",
        "               dff, vocab_size, dropout_rate=0.1):\n",
        "    super().__init__()\n",
        "\n",
        "    self.d_model = d_model\n",
        "    self.num_layers = num_layers\n",
        "\n",
        "    self.pos_embedding = PositionalEmbedding(\n",
        "        vocab_size=vocab_size, d_model=d_model)\n",
        "\n",
        "    self.enc_layers = [\n",
        "        EncoderLayer(d_model=d_model,\n",
        "                     num_heads=num_heads,\n",
        "                     dff=dff,\n",
        "                     dropout_rate=dropout_rate)\n",
        "        for _ in range(num_layers)]\n",
        "    self.dropout = tf.keras.layers.Dropout(dropout_rate)\n",
        "\n",
        "  def call(self, x):\n",
        "    # `x` is token-IDs shape: (batch, seq_len)\n",
        "    x = self.pos_embedding(x)  # Shape `(batch_size, seq_len, d_model)`.\n",
        "\n",
        "    # Add dropout.\n",
        "    x = self.dropout(x)\n",
        "\n",
        "    for i in range(self.num_layers):\n",
        "      x = self.enc_layers[i](x)\n",
        "\n",
        "    return x  # Shape `(batch_size, seq_len, d_model)`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "texobMBHLBEU"
      },
      "source": [
        "Test the encoder:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SDPXTvYgJH8s"
      },
      "outputs": [],
      "source": [
        "# Instantiate the encoder.\n",
        "sample_encoder = Encoder(num_layers=4,\n",
        "                         d_model=512,\n",
        "                         num_heads=8,\n",
        "                         dff=2048,\n",
        "                         vocab_size=8500)\n",
        "\n",
        "sample_encoder_output = sample_encoder(pt, training=False)\n",
        "\n",
        "# Print the shape.\n",
        "print(pt.shape)\n",
        "print(sample_encoder_output.shape)  # Shape `(batch_size, input_seq_len, d_model)`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6LO_48Owmx_o"
      },
      "source": [
        "### The decoder layer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GGxm57u6E4g2"
      },
      "source": [
        "The decoder's stack is slightly more complex: each `DecoderLayer` contains a `CausalSelfAttention`, a `CrossAttention`, and a `FeedForward` layer:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1ZYER7rC4FmI"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe decoder layer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/DecoderLayer.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9SoX0-vd1hue"
      },
      "outputs": [],
      "source": [
        "class DecoderLayer(tf.keras.layers.Layer):\n",
        "  def __init__(self,\n",
        "               *,\n",
        "               d_model,\n",
        "               num_heads,\n",
        "               dff,\n",
        "               dropout_rate=0.1):\n",
        "    super().__init__()\n",
        "\n",
        "    self.causal_self_attention = CausalSelfAttention(\n",
        "        num_heads=num_heads,\n",
        "        key_dim=d_model,\n",
        "        dropout=dropout_rate)\n",
        "\n",
        "    self.cross_attention = CrossAttention(\n",
        "        num_heads=num_heads,\n",
        "        key_dim=d_model,\n",
        "        dropout=dropout_rate)\n",
        "\n",
        "    self.ffn = FeedForward(d_model, dff, dropout_rate)\n",
        "\n",
        "  def call(self, x, context):\n",
        "    x = self.causal_self_attention(x=x)\n",
        "    x = self.cross_attention(x=x, context=context)\n",
        "\n",
        "    # Cache the last attention scores for plotting later\n",
        "    self.last_attn_scores = self.cross_attention.last_attn_scores\n",
        "\n",
        "    x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.\n",
        "    return x"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a6T3RSR_6nJX"
      },
      "source": [
        "Test the decoder layer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ne2Bqx8k71l0"
      },
      "outputs": [],
      "source": [
        "sample_decoder_layer = DecoderLayer(d_model=512, num_heads=8, dff=2048)\n",
        "\n",
        "sample_decoder_layer_output = sample_decoder_layer(\n",
        "    x=en_emb, context=pt_emb)\n",
        "\n",
        "print(en_emb.shape)\n",
        "print(pt_emb.shape)\n",
        "print(sample_decoder_layer_output.shape)  # `(batch_size, seq_len, d_model)`"
      ]
    },
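    {
      "cell_type": "markdown",
      "metadata": {
        "id": "crossAttnSketch01"
      },
      "source": [
        "In the `CrossAttention` sublayer, the queries come from the decoder sequence while the keys and values come from the encoder output (`context`). The result therefore keeps the target sequence length, and the attention scores have one row per target token and one column per context token, matching the `(batch, heads, target_seq, input_seq)` shape of `last_attn_scores`. Here is a minimal single-head NumPy sketch, assuming identity projections, no masking, and no batch dimension (`cross_attention` is a hypothetical helper, not the Keras API):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "crossAttnSketch02"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def cross_attention(x, context):\n",
        "  # `x`: (target_len, depth) queries from the decoder.\n",
        "  # `context`: (context_len, depth) keys and values from the encoder.\n",
        "  depth = x.shape[-1]\n",
        "  scores = x @ context.T / np.sqrt(depth)  # (target_len, context_len)\n",
        "  scores = scores - scores.max(axis=-1, keepdims=True)\n",
        "  weights = np.exp(scores)\n",
        "  weights = weights / weights.sum(axis=-1, keepdims=True)\n",
        "  return weights @ context, weights\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "target_demo = rng.normal(size=(5, 8))   # e.g. 5 English tokens\n",
        "context_demo = rng.normal(size=(7, 8))  # e.g. 7 Portuguese tokens\n",
        "\n",
        "cross_out, cross_weights = cross_attention(target_demo, context_demo)\n",
        "print(cross_out.shape)      # (5, 8): one vector per target token\n",
        "print(cross_weights.shape)  # (5, 7): target positions by context positions"
      ]
    },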
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p-uO6ls8m2O5"
      },
      "source": [
        "### The decoder"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fgj3c0TVF3Pb"
      },
      "source": [
        "Similar to the `Encoder`, the `Decoder` consists of a `PositionalEmbedding` and a stack of `DecoderLayer`s:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ADGss2nT4Gt-"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe decoder\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/Decoder.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2q49vtv5lzd3"
      },
      "source": [
        "Define the decoder by extending `tf.keras.layers.Layer`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d5_d5-PLQXwY"
      },
      "outputs": [],
      "source": [
        "class Decoder(tf.keras.layers.Layer):\n",
        "  def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size,\n",
        "               dropout_rate=0.1):\n",
        "    super().__init__()\n",
        "\n",
        "    self.d_model = d_model\n",
        "    self.num_layers = num_layers\n",
        "\n",
        "    self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,\n",
        "                                             d_model=d_model)\n",
        "    self.dropout = tf.keras.layers.Dropout(dropout_rate)\n",
        "    self.dec_layers = [\n",
        "        DecoderLayer(d_model=d_model, num_heads=num_heads,\n",
        "                     dff=dff, dropout_rate=dropout_rate)\n",
        "        for _ in range(num_layers)]\n",
        "\n",
        "    self.last_attn_scores = None\n",
        "\n",
        "  def call(self, x, context):\n",
        "    # `x` is token-IDs shape (batch, target_seq_len)\n",
        "    x = self.pos_embedding(x)  # (batch_size, target_seq_len, d_model)\n",
        "\n",
        "    x = self.dropout(x)\n",
        "\n",
        "    for i in range(self.num_layers):\n",
        "      x = self.dec_layers[i](x, context)\n",
        "\n",
        "    self.last_attn_scores = self.dec_layers[-1].last_attn_scores\n",
        "\n",
        "    # The shape of x is (batch_size, target_seq_len, d_model).\n",
        "    return x"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eALcB--YMmLf"
      },
      "source": [
        "Test the decoder:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xyHdG_jWPgKu"
      },
      "outputs": [],
      "source": [
        "# Instantiate the decoder.\n",
        "sample_decoder = Decoder(num_layers=4,\n",
        "                         d_model=512,\n",
        "                         num_heads=8,\n",
        "                         dff=2048,\n",
        "                         vocab_size=8000)\n",
        "\n",
        "output = sample_decoder(\n",
        "    x=en,\n",
        "    context=pt_emb)\n",
        "\n",
        "# Print the shapes.\n",
        "print(en.shape)\n",
        "print(pt_emb.shape)\n",
        "print(output.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ioJ4XJAUAReI"
      },
      "outputs": [],
      "source": [
        "sample_decoder.last_attn_scores.shape  # (batch, heads, target_seq, input_seq)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D3uvMP5vNuOV"
      },
      "source": [
        "Having created the Transformer encoder and decoder, it's time to build the Transformer model and train it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y54xnJnuYgJ7"
      },
      "source": [
        "## The Transformer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PSi8vBN1lzd4"
      },
      "source": [
        "You now have `Encoder` and `Decoder`. To complete the `Transformer` model, you need to put them together and add a final linear (`Dense`) layer that converts the resulting vector at each location into output token logits (unnormalized scores over the target vocabulary; a softmax turns them into probabilities).\n",
        "\n",
        "The output of the decoder is the input to this final linear layer."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "46nL2X_84Iud"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe transformer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg src=\"https://www.tensorflow.org/images/tutorials/transformer/transformer.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "trHHo2z_LC-u"
      },
      "source": [
        "A `Transformer` with one layer in both the `Encoder` and `Decoder` looks almost exactly like the model from the [RNN+attention tutorial](https://www.tensorflow.org/text/tutorials/nmt_with_attention). A multi-layer Transformer has more layers, but is fundamentally doing the same thing."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "09kxrwiaBB36"
      },
      "source": [
        "\u003ctable\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eA 1-layer transformer\u003c/th\u003e\n",
        "  \u003cth colspan=1\u003eA 4-layer transformer\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=400 src=\"https://www.tensorflow.org/images/tutorials/transformer/Transformer-1layer-compact.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd rowspan=3\u003e\n",
        "   \u003cimg width=330 src=\"https://www.tensorflow.org/images/tutorials/transformer/Transformer-4layer-compact.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003cth colspan=1\u003eThe RNN+Attention model\u003c/th\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003ctr\u003e\n",
        "  \u003ctd\u003e\n",
        "   \u003cimg width=400 src=\"https://www.tensorflow.org/images/tutorials/transformer/RNN+attention-compact.png\"/\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/tr\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I5vSbJ_gKx7C"
      },
      "source": [
        "Create the `Transformer` by extending `tf.keras.Model`:\n",
        "\n",
        "\u003e Note: The [original paper](https://arxiv.org/pdf/1706.03762.pdf), section 3.4, shares the weight matrix between the embedding layer and the final linear layer. To keep things simple, this tutorial uses two separate weight matrices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PED3bIpOYkBu"
      },
      "outputs": [],
      "source": [
        "class Transformer(tf.keras.Model):\n",
        "  def __init__(self, *, num_layers, d_model, num_heads, dff,\n",
        "               input_vocab_size, target_vocab_size, dropout_rate=0.1):\n",
        "    super().__init__()\n",
        "    self.encoder = Encoder(num_layers=num_layers, d_model=d_model,\n",
        "                           num_heads=num_heads, dff=dff,\n",
        "                           vocab_size=input_vocab_size,\n",
        "                           dropout_rate=dropout_rate)\n",
        "\n",
        "    self.decoder = Decoder(num_layers=num_layers, d_model=d_model,\n",
        "                           num_heads=num_heads, dff=dff,\n",
        "                           vocab_size=target_vocab_size,\n",
        "                           dropout_rate=dropout_rate)\n",
        "\n",
        "    self.final_layer = tf.keras.layers.Dense(target_vocab_size)\n",
        "\n",
        "  def call(self, inputs):\n",
        "    # To use a Keras model with `.fit` you must pass all your inputs in the\n",
        "    # first argument.\n",
        "    context, x = inputs\n",
        "\n",
        "    context = self.encoder(context)  # (batch_size, context_len, d_model)\n",
        "\n",
        "    x = self.decoder(x, context)  # (batch_size, target_len, d_model)\n",
        "\n",
        "    # Final linear layer output.\n",
        "    logits = self.final_layer(x)  # (batch_size, target_len, target_vocab_size)\n",
        "\n",
        "    try:\n",
        "      # Drop the keras mask, so it doesn't scale the losses/metrics.\n",
        "      # b/250038731\n",
        "      del logits._keras_mask\n",
        "    except AttributeError:\n",
        "      pass\n",
        "\n",
        "    # Return only the logits. The decoder caches its attention weights in\n",
        "    # `self.decoder.last_attn_scores`.\n",
        "    return logits"
      ]
    },
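    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tiedWeightsSketch01"
      },
      "source": [
        "For comparison, here is a minimal NumPy sketch of the weight sharing described in the note above. It is not what this tutorial's `Transformer` does: the output logits reuse the embedding matrix, transposed, in place of a separate `Dense` kernel. All names are illustrative, and the sketch omits details such as the output bias and the embedding scaling used in the paper:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tiedWeightsSketch02"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "demo_vocab_size, demo_d_model = 10, 4\n",
        "rng = np.random.default_rng(0)\n",
        "# A single shared matrix: row `i` is the embedding of token `i`.\n",
        "shared_embedding = rng.normal(size=(demo_vocab_size, demo_d_model))\n",
        "\n",
        "def embed_tokens(token_ids):\n",
        "  return shared_embedding[token_ids]  # (seq_len, d_model)\n",
        "\n",
        "def tied_logits(decoder_output):\n",
        "  # Project back onto the vocabulary with the transposed embedding matrix,\n",
        "  # instead of a separate `Dense(target_vocab_size)` kernel.\n",
        "  return decoder_output @ shared_embedding.T  # (seq_len, vocab_size)\n",
        "\n",
        "demo_token_ids = np.array([3, 1, 4])\n",
        "print(tied_logits(embed_tokens(demo_token_ids)).shape)  # (3, 10)"
      ]
    },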
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wsINyf1VEQLC"
      },
      "source": [
        "### Hyperparameters"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IjwMq_ixlzd5"
      },
      "source": [
        "To keep this example small and relatively fast, the number of layers (`num_layers`), the dimensionality of the embeddings (`d_model`), and the internal dimensionality of the `FeedForward` layer (`dff`) have been reduced.\n",
        "\n",
        "The base model described in the original Transformer paper used `num_layers=6`, `d_model=512`, and `dff=2048`.\n",
        "\n",
        "The number of self-attention heads remains the same (`num_heads=8`).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mzyo6KDfVyhl"
      },
      "outputs": [],
      "source": [
        "num_layers = 4\n",
        "d_model = 128\n",
        "dff = 512\n",
        "num_heads = 8\n",
        "dropout_rate = 0.1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g08YOE-zHRqY"
      },
      "source": [
        "### Try it out"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yYbXDEhhlzd6"
      },
      "source": [
        "Instantiate the `Transformer` model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UiysUa--4tOU"
      },
      "outputs": [],
      "source": [
        "transformer = Transformer(\n",
        "    num_layers=num_layers,\n",
        "    d_model=d_model,\n",
        "    num_heads=num_heads,\n",
        "    dff=dff,\n",
        "    input_vocab_size=tokenizers.pt.get_vocab_size().numpy(),\n",
        "    target_vocab_size=tokenizers.en.get_vocab_size().numpy(),\n",
        "    dropout_rate=dropout_rate)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qbw3CYn2tQQx"
      },
      "source": [
        "Test it:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c8eO85hpFHmE"
      },
      "outputs": [],
      "source": [
        "output = transformer((pt, en))\n",
        "\n",
        "print(en.shape)\n",
        "print(pt.shape)\n",
        "print(output.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "olTLrK8pAcLd"
      },
      "outputs": [],
      "source": [
        "attn_scores = transformer.decoder.dec_layers[-1].last_attn_scores\n",
        "print(attn_scores.shape)  # (batch, heads, target_seq, input_seq)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_jTvJsXquaHW"
      },
      "source": [
        "Print the summary of the model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IsUPhlfEtOjn"
      },
      "outputs": [],
      "source": [
        "transformer.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EfoBfC2oQtEy"
      },
      "source": [
        "## Training\n",
        "\n",
        "It's time to prepare the model and start training it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xYEGhEOtzn5W"
      },
      "source": [
        "### Set up the optimizer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SL4G5bS6lzd5"
      },
      "source": [
        "Use the Adam optimizer with a custom learning rate scheduler according to the formula in the original Transformer [paper](https://arxiv.org/abs/1706.03762).\n",
        "\n",
        "$$\\Large{lrate = d_{model}^{-0.5} * \\min(step{\\_}num^{-0.5}, step{\\_}num \\cdot warmup{\\_}steps^{-1.5})}$$"
      ]
    },
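    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lrScheduleSketch01"
      },
      "source": [
        "To see what the formula does, here is a plain-Python sketch of the same schedule (`transformer_lr` is an illustrative name). While `step_num` is below `warmup_steps`, the second term is the smaller one, so the rate grows linearly; afterwards the first term wins and the rate decays in proportion to `step_num ** -0.5`. The two terms are equal at `step_num == warmup_steps`, so that is where the peak occurs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lrScheduleSketch02"
      },
      "outputs": [],
      "source": [
        "def transformer_lr(step_num, d_model, warmup_steps=4000):\n",
        "  # Linear warmup for `warmup_steps` steps, then inverse-square-root decay.\n",
        "  # `step_num` must be at least 1 here: `0 ** -0.5` raises in Python.\n",
        "  return d_model ** -0.5 * min(step_num ** -0.5,\n",
        "                               step_num * warmup_steps ** -1.5)\n",
        "\n",
        "rates = [transformer_lr(step, d_model=128) for step in range(1, 40001)]\n",
        "peak_step = rates.index(max(rates)) + 1\n",
        "print(peak_step)  # 4000\n",
        "print(transformer_lr(peak_step, d_model=128))"
      ]
    },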
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iYQdOO1axwEI"
      },
      "outputs": [],
      "source": [
        "class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):\n",
        "  def __init__(self, d_model, warmup_steps=4000):\n",
        "    super().__init__()\n",
        "\n",
        "    self.d_model = tf.cast(d_model, tf.float32)\n",
        "\n",
        "    self.warmup_steps = warmup_steps\n",
        "\n",
        "  def __call__(self, step):\n",
        "    step = tf.cast(step, dtype=tf.float32)\n",
        "    arg1 = tf.math.rsqrt(step)\n",
        "    arg2 = step * (self.warmup_steps ** -1.5)\n",
        "\n",
        "    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fzXq5LWgRN63"
      },
      "source": [
        "Instantiate the optimizer (in this example it's `tf.keras.optimizers.Adam`):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7r4scdulztRx"
      },
      "outputs": [],
      "source": [
        "learning_rate = CustomSchedule(d_model)\n",
        "\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,\n",
        "                                     epsilon=1e-9)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fTb2S4RnQ8DU"
      },
      "source": [
        "Test the custom learning rate scheduler:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xij3MwYVRAAS"
      },
      "outputs": [],
      "source": [
        "plt.plot(learning_rate(tf.range(40000, dtype=tf.float32)))\n",
        "plt.ylabel('Learning Rate')\n",
        "plt.xlabel('Train Step')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YgkDE7hzo8r5"
      },
      "source": [
        "### Set up the loss and metrics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B6y7rNP5lzd6"
      },
      "source": [
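        "Since the target sequences are padded, it is important to apply a padding mask when calculating the loss, so that padding positions (token ID `0`) don't contribute. Use the cross-entropy loss function (`tf.keras.losses.SparseCategoricalCrossentropy`) with `reduction='none'`, then average over the unmasked tokens only. The accuracy metric is masked the same way:"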
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "67oqVHiT0Eiu"
      },
      "outputs": [],
      "source": [
        "def masked_loss(label, pred):\n",
        "  mask = label != 0\n",
        "  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
        "    from_logits=True, reduction='none')\n",
        "  loss = loss_object(label, pred)\n",
        "\n",
        "  mask = tf.cast(mask, dtype=loss.dtype)\n",
        "  loss *= mask\n",
        "\n",
        "  loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)\n",
        "  return loss\n",
        "\n",
        "\n",
        "def masked_accuracy(label, pred):\n",
        "  pred = tf.argmax(pred, axis=2)\n",
        "  label = tf.cast(label, pred.dtype)\n",
        "  match = label == pred\n",
        "\n",
        "  mask = label != 0\n",
        "\n",
        "  match = match \u0026 mask\n",
        "\n",
        "  match = tf.cast(match, dtype=tf.float32)\n",
        "  mask = tf.cast(mask, dtype=tf.float32)\n",
        "  return tf.reduce_sum(match)/tf.reduce_sum(mask)"
      ]
    },
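    {
      "cell_type": "markdown",
      "metadata": {
        "id": "maskedLossSketch01"
      },
      "source": [
        "A tiny worked example shows what the masking does. The NumPy sketch below is a re-implementation for illustration only, not the Keras loss: it computes per-token cross-entropy from logits, zeroes out padding positions (label `0`), and averages over the real tokens, mirroring `masked_loss` and `masked_accuracy` above:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "maskedLossSketch02"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def np_masked_loss_and_accuracy(label, logits):\n",
        "  # `label`: (batch, seq_len) integer token IDs, where 0 is padding.\n",
        "  # `logits`: (batch, seq_len, vocab_size) unnormalized scores.\n",
        "  shifted = logits - logits.max(axis=-1, keepdims=True)\n",
        "  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))\n",
        "  # Negative log-probability of the correct token at each position.\n",
        "  token_loss = -np.take_along_axis(log_probs, label[..., None], axis=-1)[..., 0]\n",
        "  mask = (label != 0).astype(logits.dtype)\n",
        "  loss = (token_loss * mask).sum() / mask.sum()\n",
        "  match = (logits.argmax(axis=-1) == label) * mask\n",
        "  accuracy = match.sum() / mask.sum()\n",
        "  return loss, accuracy\n",
        "\n",
        "# One sequence of length 3 whose last position is padding.\n",
        "demo_label = np.array([[2, 1, 0]])\n",
        "demo_logits = np.array([[[0.0, 0.0, 5.0],    # confident and correct\n",
        "                         [0.0, 0.0, 5.0],    # confident but wrong\n",
        "                         [9.0, 0.0, 0.0]]])  # padding: ignored\n",
        "demo_loss, demo_accuracy = np_masked_loss_and_accuracy(demo_label, demo_logits)\n",
        "print(demo_loss)\n",
        "print(demo_accuracy)  # 0.5: the padding position is not counted"
      ]
    },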
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xYEasEOsdn5W"
      },
      "source": [
        "### Train the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mk8vwuN24hafK"
      },
      "source": [
        "With all the components ready, configure the training procedure using `model.compile`, and then run it with `model.fit`:\n",
        "\n",
        "Note: This takes about an hour to train in Colab."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Una1v0hDlIsT"
      },
      "outputs": [],
      "source": [
        "transformer.compile(\n",
        "    loss=masked_loss,\n",
        "    optimizer=optimizer,\n",
        "    metrics=[masked_accuracy])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Jg35qKvVlctp"
      },
      "outputs": [],
      "source": [
        "transformer.fit(train_batches,\n",
        "                epochs=20,\n",
        "                validation_data=val_batches)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cxKpqCbzSW6z"
      },
      "source": [
        "## Run inference"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mk8vwuN1SafK"
      },
      "source": [
        "You can now test the model by performing a translation. The following steps are used for inference:\n",
        "\n",
        "* Encode the input sentence using the Portuguese tokenizer (`tokenizers.pt`). This is the encoder input.\n",
        "* Initialize the decoder input to the `[START]` token.\n",
        "* The padding and look-ahead (causal) masks are applied inside the model's layers, so they don't need to be computed here.\n",
        "* The decoder outputs its predictions by attending to the encoder output (cross-attention) and to its own previous outputs (causal self-attention).\n",
        "* Append the predicted token to the decoder input and pass the result back to the decoder.\n",
        "* In this way, the decoder predicts each next token based on the tokens it has already predicted, stopping at the `[END]` token or after `max_length` steps.\n",
        "\n",
        "Note: The model is optimized for _efficient training_ and makes a next-token prediction for each token in the output simultaneously. This is redundant during inference, where only the last prediction is used. The model could be made more efficient for inference by computing only the last prediction when running in inference mode (`training=False`).\n",
        "\n",
        "Define the `Translator` class by subclassing `tf.Module`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eY_uXsOhSmbb"
      },
      "outputs": [],
      "source": [
        "class Translator(tf.Module):\n",
        "  def __init__(self, tokenizers, transformer):\n",
        "    self.tokenizers = tokenizers\n",
        "    self.transformer = transformer\n",
        "\n",
        "  def __call__(self, sentence, max_length=MAX_TOKENS):\n",
        "    # The input sentence is Portuguese, hence adding the `[START]` and `[END]` tokens.\n",
        "    assert isinstance(sentence, tf.Tensor)\n",
        "    if len(sentence.shape) == 0:\n",
        "      sentence = sentence[tf.newaxis]\n",
        "\n",
        "    sentence = self.tokenizers.pt.tokenize(sentence).to_tensor()\n",
        "\n",
        "    encoder_input = sentence\n",
        "\n",
        "    # As the output language is English, initialize the output with the\n",
        "    # English `[START]` token.\n",
        "    start_end = self.tokenizers.en.tokenize([''])[0]\n",
        "    start = start_end[0][tf.newaxis]\n",
        "    end = start_end[1][tf.newaxis]\n",
        "\n",
        "    # `tf.TensorArray` is required here (instead of a Python list), so that the\n",
        "    # dynamic-loop can be traced by `tf.function`.\n",
        "    output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)\n",
        "    output_array = output_array.write(0, start)\n",
        "\n",
        "    for i in tf.range(max_length):\n",
        "      output = tf.transpose(output_array.stack())\n",
        "      predictions = self.transformer([encoder_input, output], training=False)\n",
        "\n",
        "      # Select the last token from the `seq_len` dimension.\n",
        "      predictions = predictions[:, -1:, :]  # Shape `(batch_size, 1, vocab_size)`.\n",
        "\n",
        "      predicted_id = tf.argmax(predictions, axis=-1)\n",
        "\n",
        "      # Concatenate the `predicted_id` to the output which is given to the\n",
        "      # decoder as its input.\n",
        "      output_array = output_array.write(i+1, predicted_id[0])\n",
        "\n",
        "      if predicted_id == end:\n",
        "        break\n",
        "\n",
        "    output = tf.transpose(output_array.stack())\n",
        "    # The output shape is `(1, tokens)`.\n",
        "    text = self.tokenizers.en.detokenize(output)[0]  # Shape: `()`.\n",
        "\n",
        "    tokens = self.tokenizers.en.lookup(output)[0]\n",
        "\n",
        "    # `tf.function` prevents us from using the attention_weights that were\n",
        "    # calculated on the last iteration of the loop.\n",
        "    # So, recalculate them outside the loop.\n",
        "    self.transformer([encoder_input, output[:,:-1]], training=False)\n",
        "    attention_weights = self.transformer.decoder.last_attn_scores\n",
        "\n",
        "    return text, tokens, attention_weights"
      ]
    },
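    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchGreedyMd"
      },
      "source": [
        "The loop in `Translator.__call__` is greedy autoregressive decoding. At each step it feeds in the tokens generated so far, takes the `argmax` of the last position's logits, and appends the result. It stops at `[END]` or after `max_length` steps. The following framework-free sketch shows the same control flow in plain Python. `toy_next_token_logits` and the token ids are hypothetical stand-ins for the Transformer and the English tokenizer. The toy model returns logits for the last position only, which corresponds to `predictions[:, -1:, :]` above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchGreedyCode"
      },
      "outputs": [],
      "source": [
        "# A framework-free sketch of greedy decoding. The model below is a\n",
        "# hypothetical stand-in for the Transformer, not part of this tutorial.\n",
        "START, END = 1, 2\n",
        "\n",
        "def toy_next_token_logits(encoder_tokens, output_tokens, vocab_size=6):\n",
        "  # Hypothetical model: copies the encoder tokens in order, then emits END.\n",
        "  position = len(output_tokens) - 1  # Tokens generated so far.\n",
        "  logits = [0.0] * vocab_size\n",
        "  if position \u003c len(encoder_tokens):\n",
        "    logits[encoder_tokens[position]] = 1.0\n",
        "  else:\n",
        "    logits[END] = 1.0\n",
        "  return logits\n",
        "\n",
        "def greedy_decode(encoder_tokens, next_token_logits, max_length=10):\n",
        "  output = [START]  # Like `output_array.write(0, start)`.\n",
        "  for _ in range(max_length):\n",
        "    logits = next_token_logits(encoder_tokens, output)\n",
        "    predicted_id = max(range(len(logits)), key=logits.__getitem__)  # argmax\n",
        "    output.append(predicted_id)\n",
        "    if predicted_id == END:\n",
        "      break\n",
        "  return output\n",
        "\n",
        "print(greedy_decode([3, 4, 5], toy_next_token_logits))  # [1, 3, 4, 5, 2]"
      ]
    },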
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mJ3o-65iS6CN"
      },
      "source": [
        "Note: This function uses an unrolled loop, not a dynamic loop. It generates `MAX_TOKENS` on every call. Refer to the [NMT with attention](nmt_with_attention.ipynb) tutorial for an example implementation with a dynamic loop, which can be much more efficient."
      ]
    },
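    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchCausalMd"
      },
      "source": [
        "To see why computing every position is redundant at inference time, consider a causal model. With a look-ahead mask, the prediction at position `t` depends only on the first `t + 1` output tokens and the fixed encoder output. Re-running the model on a longer prefix therefore reproduces all earlier predictions unchanged. The hypothetical `causal_predictions` function below is a minimal sketch of that property, not the Transformer itself."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchCausalCode"
      },
      "outputs": [],
      "source": [
        "# Hypothetical causal 'model': the prediction at position t depends only on\n",
        "# tokens[:t+1] (a running sum), mimicking the decoder's look-ahead mask.\n",
        "def causal_predictions(tokens):\n",
        "  predictions, running = [], 0\n",
        "  for token in tokens:\n",
        "    running += token\n",
        "    predictions.append(running % 7)  # Toy next-token id.\n",
        "  return predictions\n",
        "\n",
        "step_3 = causal_predictions([1, 3, 4])\n",
        "step_4 = causal_predictions([1, 3, 4, 5])\n",
        "print(step_3, step_4)  # [1, 4, 1] [1, 4, 1, 6]\n",
        "# Every prediction except the newest one is recomputed unchanged.\n",
        "print(step_4[:-1] == step_3)  # True"
      ]
    },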
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TeUJafisS435"
      },
      "source": [
        "Create an instance of this `Translator` class, and try it out a few times:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-NjbvpHUTEia"
      },
      "outputs": [],
      "source": [
        "translator = Translator(tokenizers, transformer)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QfHSRdejTFsC"
      },
      "outputs": [],
      "source": [
        "def print_translation(sentence, translation, ground_truth):\n",
        "  print(f'{\"Input\":15s}: {sentence}')\n",
        "  print(f'{\"Prediction\":15s}: {translation.numpy().decode(\"utf-8\")}')\n",
        "  print(f'{\"Ground truth\":15s}: {ground_truth}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "buUeDo58TIoD"
      },
      "source": [
        "Example 1:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o9CEm4cuTGtw"
      },
      "outputs": [],
      "source": [
        "sentence = 'este é um problema que temos que resolver.'\n",
        "ground_truth = 'this is a problem we have to solve .'\n",
        "\n",
        "translated_text, translated_tokens, attention_weights = translator(\n",
        "    tf.constant(sentence))\n",
        "print_translation(sentence, translated_text, ground_truth)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sfJrFBZ6TJxc"
      },
      "source": [
        "Example 2:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "elmz_Ly7THuJ"
      },
      "outputs": [],
      "source": [
        "sentence = 'os meus vizinhos ouviram sobre esta ideia.'\n",
        "ground_truth = 'and my neighboring homes heard about this idea .'\n",
        "\n",
        "translated_text, translated_tokens, attention_weights = translator(\n",
        "    tf.constant(sentence))\n",
        "print_translation(sentence, translated_text, ground_truth)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EY7NfEjrTOCr"
      },
      "source": [
        "Example 3:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bmmtPo3vTOwj"
      },
      "outputs": [],
      "source": [
        "sentence = 'vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.'\n",
        "ground_truth = \"so i'll just share with you some stories very quickly of some magical things that have happened.\"\n",
        "\n",
        "translated_text, translated_tokens, attention_weights = translator(\n",
        "    tf.constant(sentence))\n",
        "print_translation(sentence, translated_text, ground_truth)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aB_03k0kTQLb"
      },
      "source": [
        "## Create attention plots"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "miZXl9i-TSs6"
      },
      "source": [
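        "Besides the translated text and tokens, the `Translator` class you created in the previous section returns attention weights (the cross-attention scores from the decoder's last layer). You can plot these as heatmaps to visualize the internal workings of the model.\n",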
        "\n",
        "For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V3m2wcNLTU8K"
      },
      "outputs": [],
      "source": [
        "sentence = 'este é o primeiro livro que eu fiz.'\n",
        "ground_truth = \"this is the first book i've ever done.\"\n",
        "\n",
        "translated_text, translated_tokens, attention_weights = translator(\n",
        "    tf.constant(sentence))\n",
        "print_translation(sentence, translated_text, ground_truth)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-rhE_LW7TZ40"
      },
      "source": [
        "Create a function that plots a single attention head as a heatmap, with the input tokens on the x-axis and the generated tokens on the y-axis:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gKlxYO0JTXzD"
      },
      "outputs": [],
      "source": [
        "def plot_attention_head(in_tokens, translated_tokens, attention):\n",
        "  # The model didn't generate `[START]` in the output. Skip it.\n",
        "  translated_tokens = translated_tokens[1:]\n",
        "\n",
        "  ax = plt.gca()\n",
        "  ax.matshow(attention)\n",
        "  ax.set_xticks(range(len(in_tokens)))\n",
        "  ax.set_yticks(range(len(translated_tokens)))\n",
        "\n",
        "  labels = [label.decode('utf-8') for label in in_tokens.numpy()]\n",
        "  ax.set_xticklabels(\n",
        "      labels, rotation=90)\n",
        "\n",
        "  labels = [label.decode('utf-8') for label in translated_tokens.numpy()]\n",
        "  ax.set_yticklabels(labels)"
      ]
    },
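    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchAlignMd"
      },
      "source": [
        "`plot_attention_head` drops `[START]` from the y-axis labels because of an off-by-one alignment. `Translator` recomputes the attention with `output[:,:-1]` as the decoder input, so query row `i` is the step that predicted token `i + 1`. The sketch below lists that alignment for a made-up token list, not a real model output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchAlignCode"
      },
      "outputs": [],
      "source": [
        "# Hypothetical token list, for illustration only.\n",
        "tokens = ['[START]', 'this', 'is', 'the', 'first', 'book', '[END]']\n",
        "decoder_input = tokens[:-1]  # What the decoder reads, like `output[:,:-1]`.\n",
        "predicted = tokens[1:]       # What each query row predicts: the y labels.\n",
        "for row, (seen, label) in enumerate(zip(decoder_input, predicted)):\n",
        "  print(f'row {row}: reads {seen!r:10} predicts {label!r}')"
      ]
    },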
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yI4YWU2uXDeW"
      },
      "outputs": [],
      "source": [
        "head = 0\n",
        "# Shape: `(batch=1, num_heads, seq_len_q, seq_len_k)`.\n",
        "attention_heads = tf.squeeze(attention_weights, 0)\n",
        "attention = attention_heads[head]\n",
        "attention.shape"
      ]
    },
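    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchShapesMd"
      },
      "source": [
        "The attention weights carry a leading batch dimension, which `tf.squeeze` removes before a single head is selected. The following NumPy sketch shows the resulting shapes and checks that each query row of a softmax attention matrix sums to 1. It uses random scores and hypothetical sizes rather than the model's real output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchShapesCode"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical sizes: 1 sentence, 8 heads, 6 query (output) tokens and\n",
        "# 9 key (input) tokens. Random scores stand in for real attention logits.\n",
        "batch, num_heads, seq_len_q, seq_len_k = 1, 8, 6, 9\n",
        "rng = np.random.default_rng(0)\n",
        "scores = rng.normal(size=(batch, num_heads, seq_len_q, seq_len_k))\n",
        "\n",
        "# Softmax over the key axis, so each query row is a distribution over inputs.\n",
        "weights = np.exp(scores - scores.max(axis=-1, keepdims=True))\n",
        "weights /= weights.sum(axis=-1, keepdims=True)\n",
        "\n",
        "heads = np.squeeze(weights, axis=0)  # Like `tf.squeeze(attention_weights, 0)`.\n",
        "attention = heads[0]                 # One head: `(seq_len_q, seq_len_k)`.\n",
        "print(heads.shape, attention.shape)  # (8, 6, 9) (6, 9)\n",
        "print(np.allclose(attention.sum(axis=-1), 1.0))  # True"
      ]
    },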
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "facNouzOXMSu"
      },
      "source": [
        "These are the input (Portuguese) tokens:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SMEpyioWTmSN"
      },
      "outputs": [],
      "source": [
        "in_tokens = tf.convert_to_tensor([sentence])\n",
        "in_tokens = tokenizers.pt.tokenize(in_tokens).to_tensor()\n",
        "in_tokens = tokenizers.pt.lookup(in_tokens)[0]\n",
        "in_tokens"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JLg9HTKCXPKz"
      },
      "source": [
        "And these are the output (English translation) tokens:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GzvIo5uYTnHG"
      },
      "outputs": [],
      "source": [
        "translated_tokens"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lrNh47D1ToBD"
      },
      "outputs": [],
      "source": [
        "plot_attention_head(in_tokens, translated_tokens, attention)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iMZr-rI_TrGh"
      },
      "outputs": [],
      "source": [
        "def plot_attention_weights(sentence, translated_tokens, attention_heads):\n",
        "  in_tokens = tf.convert_to_tensor([sentence])\n",
        "  in_tokens = tokenizers.pt.tokenize(in_tokens).to_tensor()\n",
        "  in_tokens = tokenizers.pt.lookup(in_tokens)[0]\n",
        "\n",
        "  fig = plt.figure(figsize=(16, 8))\n",
        "\n",
        "  for h, head in enumerate(attention_heads):\n",
        "    ax = fig.add_subplot(2, 4, h+1)\n",
        "\n",
        "    plot_attention_head(in_tokens, translated_tokens, head)\n",
        "\n",
        "    ax.set_xlabel(f'Head {h+1}')\n",
        "\n",
        "  plt.tight_layout()\n",
        "  plt.show()"
      ]
    },
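    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchGridMd"
      },
      "source": [
        "`plot_attention_weights` hard-codes a 2-by-4 grid, which fits a model with 8 attention heads. If you change the number of heads, a hypothetical helper like `grid_shape` below can compute the layout. You could then call `fig.add_subplot(*grid_shape(len(attention_heads)), h+1)` instead."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchGridCode"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def grid_shape(num_heads, num_cols=4):\n",
        "  # Hypothetical helper: rows and columns needed to show every head.\n",
        "  return math.ceil(num_heads / num_cols), num_cols\n",
        "\n",
        "print(grid_shape(8))   # (2, 4): the layout hard-coded above.\n",
        "print(grid_shape(12))  # (3, 4)"
      ]
    },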
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lBMujUb1Tr4C"
      },
      "outputs": [],
      "source": [
        "plot_attention_weights(sentence,\n",
        "                       translated_tokens,\n",
        "                       attention_weights[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9N5S5IptTtHI"
      },
      "source": [
        "The model can cope with unfamiliar words to some extent. Neither `'triceratops'` nor `'enciclopédia'` is in the input dataset, yet the model attempts to transliterate them even without a shared vocabulary. For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w0-5gjfWT0CS"
      },
      "outputs": [],
      "source": [
        "sentence = 'Eu li sobre triceratops na enciclopédia.'\n",
        "ground_truth = 'I read about triceratops in the encyclopedia.'\n",
        "\n",
        "translated_text, translated_tokens, attention_weights = translator(\n",
        "    tf.constant(sentence))\n",
        "print_translation(sentence, translated_text, ground_truth)\n",
        "\n",
        "plot_attention_weights(sentence, translated_tokens, attention_weights[0])"
      ]
    },
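    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchWordPieceMd"
      },
      "source": [
        "The model can process unfamiliar words at all because the tokenizers split them into known subword pieces. Assuming a WordPiece-style vocabulary (the scheme BERT tokenizers use), a word is split greedily into the longest matching pieces, and continuation pieces carry a `##` prefix. The vocabulary below is made up for illustration, so the real tokenizers will split these words differently. To see the actual pieces, look them up as in the cells that build `in_tokens` above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchWordPieceCode"
      },
      "outputs": [],
      "source": [
        "# A minimal WordPiece-style sketch: greedy longest-match-first splitting,\n",
        "# with '##' marking pieces that continue a word. `toy_vocab` is hypothetical.\n",
        "def wordpiece(word, vocab, unk='[UNK]'):\n",
        "  pieces, start = [], 0\n",
        "  while start \u003c len(word):\n",
        "    end, piece = len(word), None\n",
        "    while start \u003c end:  # Try the longest candidate first.\n",
        "      candidate = word[start:end]\n",
        "      if start \u003e 0:\n",
        "        candidate = '##' + candidate\n",
        "      if candidate in vocab:\n",
        "        piece = candidate\n",
        "        break\n",
        "      end -= 1\n",
        "    if piece is None:\n",
        "      return [unk]  # No piece matches: the whole word becomes unknown.\n",
        "    pieces.append(piece)\n",
        "    start = end\n",
        "  return pieces\n",
        "\n",
        "toy_vocab = {'tri', '##cer', '##ato', '##ps', 'enc', '##iclo', '##pé', '##dia'}\n",
        "print(wordpiece('triceratops', toy_vocab))   # ['tri', '##cer', '##ato', '##ps']\n",
        "print(wordpiece('enciclopédia', toy_vocab))  # ['enc', '##iclo', '##pé', '##dia']"
      ]
    },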
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9zz4uIDbT1OU"
      },
      "source": [
        "## Export the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zunHPJJzT4Cz"
      },
      "source": [
        "You have tested the model and the inference is working. Next, you can export it as a `tf.saved_model`. To learn more about saving and loading a model in the SavedModel format, refer to [this guide](https://www.tensorflow.org/guide/saved_model).\n",
        "\n",
        "Create a class called `ExportTranslator` that subclasses `tf.Module` and wraps its `__call__` method in a `tf.function`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NZhv5h4AT_n5"
      },
      "outputs": [],
      "source": [
        "class ExportTranslator(tf.Module):\n",
        "  def __init__(self, translator):\n",
        "    self.translator = translator\n",
        "\n",
        "  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])\n",
        "  def __call__(self, sentence):\n",
        "    (result,\n",
        "     tokens,\n",
        "     attention_weights) = self.translator(sentence, max_length=MAX_TOKENS)\n",
        "\n",
        "    return result"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wad7lUtPUAnf"
      },
      "source": [
        "The `tf.function` above returns only the output sentence. Thanks to [non-strict execution](https://tensorflow.org/guide/intro_to_graphs) in `tf.function`, unnecessary values (those the returned sentence does not depend on) are never computed."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-7KJEFWI5v84"
      },
      "source": [
        "Wrap `translator` in the newly created `ExportTranslator`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wm1_eRPvUCUm"
      },
      "outputs": [],
      "source": [
        "translator = ExportTranslator(translator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7VPH4T5XUDnc"
      },
      "source": [
        "Because the model decodes the predictions with `tf.argmax`, the predictions are deterministic. The original model and one reloaded from its `SavedModel` should therefore give identical predictions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GITRCiAYUE5w"
      },
      "outputs": [],
      "source": [
        "translator('este é o primeiro livro que eu fiz.').numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_v--e1XmUFw3"
      },
      "outputs": [],
      "source": [
        "tf.saved_model.save(translator, export_dir='translator')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5KJSQEzlUGo-"
      },
      "outputs": [],
      "source": [
        "reloaded = tf.saved_model.load('translator')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lIVpKWBNUHhr"
      },
      "outputs": [],
      "source": [
        "reloaded('este é o primeiro livro que eu fiz.').numpy()"
      ]
    },
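    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchDeterminismMd"
      },
      "source": [
        "The determinism comes from greedy decoding: `argmax` always picks the same id for the same logits. Sampling from the softmax distribution, which this tutorial does not use, would make repeated calls differ unless the random seed were fixed. The plain-Python sketch below contrasts the two on hypothetical logits."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchDeterminismCode"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "import random\n",
        "\n",
        "logits = [2.0, 1.0, 0.5, 3.0]  # Hypothetical next-token scores.\n",
        "\n",
        "def greedy(logits):\n",
        "  return max(range(len(logits)), key=logits.__getitem__)  # Like `tf.argmax`.\n",
        "\n",
        "def sample(logits, rng, temperature=1.0):\n",
        "  # Softmax sampling (not used in this tutorial) depends on the random state.\n",
        "  weights = [math.exp(score / temperature) for score in logits]\n",
        "  return rng.choices(range(len(logits)), weights=weights)[0]\n",
        "\n",
        "rng = random.Random()\n",
        "print({greedy(logits) for _ in range(100)})  # {3}: always the same id.\n",
        "print({sample(logits, rng) for _ in range(100)})  # Varies from run to run."
      ]
    },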
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ri2i6cTxUI00"
      },
      "source": [
        "## Conclusion\n",
        "\n",
        "In this tutorial you learned about:\n",
        "\n",
        "* Transformers and their significance in machine learning\n",
        "* Attention, self-attention and multi-head attention\n",
        "* Positional encoding with embeddings\n",
        "* The encoder-decoder architecture of the original Transformer\n",
        "* Masking in self-attention\n",
        "* How to put it all together to translate text\n",
        "\n",
        "The downsides of this architecture are:\n",
        "\n",
        "- For a time series, the output for a time step is calculated from the *entire history* instead of only the inputs and current hidden state. This _may_ be less efficient.\n",
        "- If the input has a temporal/spatial relationship, like text or images, some positional encoding must be added or the model will effectively see a bag of words.\n",
        "\n",
        "If you want to practice, there are many things you could try with it. For example:\n",
        "\n",
        "* Use a different dataset to train the Transformer.\n",
        "* Create the \"base\" or \"big\" Transformer configurations from the original paper by changing the hyperparameters.\n",
        "* Use the layers defined here to create an implementation of [BERT](https://arxiv.org/abs/1810.04805).\n",
        "* Use beam search to get better predictions.\n",
        "\n",
        "There is a wide variety of Transformer-based models, many of which improve upon the original 2017 Transformer with encoder-decoder, encoder-only and decoder-only architectures.\n",
        "\n",
        "Some of these models are covered in the following research publications:\n",
        "\n",
        "* [\"Efficient Transformers: a survey\"](https://arxiv.org/abs/2009.06732) (Tay et al., 2022)\n",
        "* [\"Formal algorithms for Transformers\"](https://arxiv.org/abs/2207.09238) (Phuong and Hutter, 2022)\n",
        "* [T5 (\"Exploring the limits of transfer learning with a unified text-to-text Transformer\")](https://arxiv.org/abs/1910.10683) (Raffel et al., 2019)\n",
        "\n",
        "You can learn more about other models in the following Google blog posts:\n",
        "\n",
        "* [PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html)\n",
        "* [LaMDA](https://ai.googleblog.com/2022/01/lamda-towards-safe-grounded-and-high.html)\n",
        "* [MUM](https://blog.google/products/search/introducing-mum/)\n",
        "* [Reformer](https://ai.googleblog.com/2020/01/reformer-efficient-transformer.html)\n",
        "* [BERT](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html)\n",
        "\n",
        "If you're interested in studying how attention-based models have been applied in tasks outside of natural language processing, check out the following resources:\n",
        "\n",
        "- Vision Transformer (ViT): [Transformers for image recognition at scale](https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html)\n",
        "- [Multi-task multitrack music transcription (MT3)](https://magenta.tensorflow.org/transcription-with-transformers) with a Transformer\n",
        "- [Code generation with AlphaCode](https://www.deepmind.com/blog/competitive-programming-with-alphacode)\n",
        "- [Reinforcement learning with multi-game decision Transformers](https://ai.googleblog.com/2022/07/training-generalist-agents-with-multi.html)\n",
        "- [Protein structure prediction with AlphaFold](https://www.nature.com/articles/s41586-021-03819-2)\n",
        "- [OptFormer: Towards universal hyperparameter optimization with Transformers](http://ai.googleblog.com/2022/08/optformer-towards-universal.html)\n"
      ]
    },
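    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchBeamMd"
      },
      "source": [
        "As a starting point for the beam search exercise, here is a minimal plain-Python sketch. `TOY_MODEL` is a hypothetical next-token distribution, not the trained Transformer. Hypotheses are ranked by their summed log-probability, without the length normalization that practical systems often add. With `beam_width=1` the search reduces to greedy decoding and picks `S A E` (probability 0.24). A width of 2 finds the more likely `S B E` (probability 0.36)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sketchBeamCode"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "START, END, A, B = 'S', 'E', 'A', 'B'\n",
        "\n",
        "# Hypothetical next-token probabilities, keyed by the prefix so far.\n",
        "TOY_MODEL = {\n",
        "    (START,): {A: 0.6, B: 0.4},\n",
        "    (START, A): {END: 0.4, A: 0.3, B: 0.3},\n",
        "    (START, B): {END: 0.9, A: 0.1},\n",
        "}\n",
        "\n",
        "def next_token_probs(prefix):\n",
        "  return TOY_MODEL.get(prefix, {END: 1.0})  # Unlisted prefixes end here.\n",
        "\n",
        "def beam_search(beam_width=2, max_length=5):\n",
        "  beams = [((START,), 0.0)]  # (tokens, total log-probability)\n",
        "  for _ in range(max_length):\n",
        "    candidates = []\n",
        "    for tokens, score in beams:\n",
        "      if tokens[-1] == END:  # Keep finished hypotheses unchanged.\n",
        "        candidates.append((tokens, score))\n",
        "        continue\n",
        "      for token, p in next_token_probs(tokens).items():\n",
        "        candidates.append((tokens + (token,), score + math.log(p)))\n",
        "    # Keep only the `beam_width` highest-scoring hypotheses.\n",
        "    beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]\n",
        "    if all(tokens[-1] == END for tokens, _ in beams):\n",
        "      break\n",
        "  return beams[0]\n",
        "\n",
        "for width in (1, 2):\n",
        "  tokens, score = beam_search(beam_width=width)\n",
        "  # Prints 1 ('S', 'A', 'E') 0.24, then 2 ('S', 'B', 'E') 0.36.\n",
        "  print(width, tokens, round(math.exp(score), 2))"
      ]
    },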
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Kk6IeFbP0ei"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
